Python+Tensorflow机器学习实战
上QQ阅读APP看书,第一时间看更新

3.1 加载数据

在TensorFlow中加载数据的方式一共有三种:预加载数据、填充数据以及从CSV文件读取数据。

3.1.1 预加载数据

预加载数据就是在TensorFlow图中定义常量或变量来保存所有数据,例如:

  
     01 a=tf.constant([1,2])
     02 b=tf.constant([3,4])
     03 y=tf.add(a,b)

因为常量会直接存储在数据流图的数据结构中,所以在训练过程中这个结构体可能会被复制多次,从而导致消耗大量内存。

3.1.2 填充数据

TensorFlow提供的数据填充机制允许在TensorFlow的计算图训练过程中,将数据填充到任意张量中。可通过会话的run()函数中的feed_dict参数获取数据,例如:

当数据量大时,填充数据的方式也存在消耗内存的缺点。

3.1.3 从CSV文件读取数据

TensorFlow从文件中读取数据的方式主要有两种,一种是直接从原始文件中读取数据,另一种是将原始文件格式转换为TensorFlow定义的TFRecords格式后再进行读取。TensorFlow提供了如下对应方法。

class tf.TextLineReader:读取文件中的一行文本,返回两个Tensor对象,如(key, value)。

class tf.WholeFileReader:读取整个文件,返回两个值,分别是文件名称和文件内容。

class tf.IdentityReader:以key和value的形式输出一个work队列。

class tf.FixedLengthRecordReader:从二进制文件中读取固定长度的记录。

class tf.TFRecordReader:读取TFRecords格式的文件。

对于从文件中读取数据,首先使用读取器将数据读取到队列中,然后从队列中获取数据并进行处理。下面以读取CSV格式的文件来讲解TensorFlow从源文件中直接读取数据的具体过程。

1.创建队列

TensorFlow提供了队列的创建方法:

其中,string_tensor是读取的文件名列表,num_epochs是文件的训练次数,shuffle表示是否对文件进行乱序处理。需要注意的是,返回队列的队列管理器与文件读取器的线程是分开的。

从airline.csv文件中读取数据,创建队列的具体实现如下:

2.创建读取器获取数据

在TensorFlow中,针对不同的文件格式,提供了不同的文件读取器。在文件读取器中提供了read()方法,用于获取文件内容和文件的表征值key,从而将内容转换为张量,用于TensorFlow的解析。

例如在图3.1中,airline.csv文件中的每行数据都包含6个浮点数据。

图3.1 airline.csv文件

对于该文件中数据的读取,分为创建读取器、获取数据和解析数据三个步骤。

对于CSV文件中数据的读取,使用TextLineReader来创建读取器。

使用读取器的read()方法来获取数据。

对于数据的解析,需要根据读取的CSV文件的数据格式和数据类型来定义格式。然后使用decode_csv()方法来解析内容。具体实现如下:

其中,第03行根据读取文件的数据类型构造对应的数据类型,而且必须是list形式。

第04行将每一行读取的内容(value)按照数据类型(record_defaults)解析到张量col1、col2、col3、col4、col5和col6中。

3.处理数据

在此对获取的数据进行打印输出。需要注意的是,由于队列管理器与文件阅读器的线程是相互独立的,因此需要先启用队列,再使用线程协调器来管理这两个线程。具体实现如下:

运行上述代码,结果如图3.2所示。

图3.2 输出结果

将输出结果与源文件airline.csv进行对比,可以很明显地看到差异。

3.1.4 读取TFRecords数据

在机器学习中,处理的数据量都非常巨大,常用的数据读取方式一般都会存在内存占用过高的问题。TensorFlow针对该问题进行了优化,定义了TFRecords格式文件。

TFRecords是一种二进制文件,能更好地利用内存,更方便地进行复制和移动,并且不需要单独标记文件,可以使TensorFlow的数据集更容易与网络应用架构相匹配。采用这种方式读取数据分为如下两个步骤。

①把样本数据转换为TFRecords二进制文件。

②读取TFRecords格式文件。

前一章介绍了MNIST数据集,本节将继续以MNIST数据集为数据源,将其转为TFRecords格式文件,然后读取TFRecords格式文件中的几张图片。

1.生成TFRecords文件

TFRecords文件中的数据是通过tf.train.Example协议缓冲区的格式存储的。生成TFRecords文件就是将数据填入tf.train.Example协议缓存区,然后将该协议缓冲区序列化为一个字符串,写入TFRecords文件。

tensorflow\core\example目录的example.proto和feature.proto文件中给出了tf.train.Example协议缓冲区的定义:

从上述代码可以看到,tf.train.Example协议缓冲区的数据结构相对简洁,可以理解成属性名和属性值的对应关系表。其中,属性名是一个字符串,属性值可以是字符串(BytesList)、实数列表(FloatList)或整数列表(Int64List)。

因此,将数据填入tf.train.Example协议缓冲区的过程,就是构建tf.train.Example数据结构的过程。对于MNIST数据集中的数据,我们构建的数据结构中仅保存两个属性:标签和图像。数据结构的具体实现如下:

通过TFRecordWriter()方法,将缓冲区的数据序列化后写入TFRecords文件。生成TFRecords文件的具体实现如下:

其中,对于从文件中读取的数据,第10~15行定义了将它们转换为与协议缓冲区匹配的整数列表和字符串类型。

第19~22行读取MNIST数据集文件,并获取与数据对应的图像和标签。

运行上述代码,将在代码所在文件夹中生成output.tfrecords文件。使用二进制文件编辑器打开该文件,显示结果如图3.3所示。

图3.3 output.tfrecords文件

2.读取TFRecords文件

读取TFRecords文件就是使用队列读取TFRecords文件中的数据,可以分为以下两个步骤。

①获取一个协议缓冲区,解析对应属性,转换为张量。

②将张量作为输入进行训练处理。

首先,使用队列从TFRecords文件中读取数据,然后使用tf.TFRecordReader的tf.parse_single_example操作将tf.train.Example协议缓冲区解析为张量,具体实现如下:

由于输入图像的处理可以是无序的,因此使用tf.train.shuffle_batch生成随机队列以进行多线程的样本处理。具体实现如下:

最后,将获取的文件张量batch在训练中进行处理。在此将获取的张量转换为图片进行保存。具体实现如下:

运行上述代码,在代码所在文件夹中生成了对应的5个图片文件,结果显示如图3.4所示。

图3.4 图片文件