![TensorFlow从零开始学](https://wfqqreader-1252317822.image.myqcloud.com/cover/344/31186344/b_31186344.jpg)
2.3 TensorFlow2.0的使用
2.3.1 “tf.data”API
除GPU和TPU等硬件加速设备外,高效的数据输入管道也可以很大程度地提升模型性能,减少模型训练所需要的时间。数据输入管道本质是一个ELT(Extract、Transform和Load)过程:
●Extract:从硬盘中读取数据(可以是本地的,也可以是云端的)。
●Transform:数据的预处理(如数据清洗、格式转换等)。
●Load:将处理好的数据加载到计算设备(例如CPU、GPU及TPU等)。
数据输入管道一般使用CPU来执行ELT过程,GPU等其他硬件加速设备则负责模型的训练,ELT过程和模型的训练并行执行,从而提高模型训练的效率。另外ELT过程的各个步骤也都可以进行相应的优化,例如并行地读取和处理数据等。在TensorFlow中可以使用“tf.data”API来构建这样的数据输入管道。
这里使用的是一个花朵图片的数据集,如图2-13所示,除一个License文件外,主要是五个分别存放着对应类别花朵图片的文件夹,其中“daisy(雏菊)”文件夹中有633张图片,“dandelion(蒲公英)”文件夹中有898张图片,“roses(玫瑰)”文件夹中有641张图片,“sunflowers(向日葵)”文件夹中有699张图片,“tulips(郁金香)”文件夹中有799张图片。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_19.jpg?sign=1738836529-qHKi6L7TyXNf2MEI8GKXIc2Zj6soTvgO-0-679f6b65549035fc1d72dd46d0391dbd)
图2-13 解压后的数据集
接下来开始实现代码,导入需要使用的包:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_20.jpg?sign=1738836529-AnE31X4GdES2TwrE2X6rTGRGZai1O5Vf-0-a265772621060a677a795c6ab034cd4c)
pathlib提供了一组用于处理文件系统路径的类。导入需要的包后,可以先检查一下TensorFlow的版本:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_21.jpg?sign=1738836529-5cMT1GSXK5ypIs5pU9UuRXzXjulPQN3d-0-5eea751d192513d9bfd7c8e9a009669b)
获取所有图片样本文件的路径:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_22.jpg?sign=1738836529-zU1abte2MIJECcSZLxnyv0zM5yHStriO-0-6d9628cd12aac41d3b30ab2276a9b810)
输出结果如图2-14所示。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_24.jpg?sign=1738836529-WbXYpRuXKBOm4VoVHMijcbLZEFIFbiKQ-0-33a8c7728bbe221822d16edac2969112)
图2-14 文件路径输出结果
接下来统计图片的类别,并给每一个类别分配一个类标:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_25.jpg?sign=1738836529-L6cDd3lzpSToitIHfGOJfjKaYmM3AvFY-0-d6178f0fbdf065ee1fb0fe4569a4f7e8)
输出结果如图2-15所示,daisy(雏菊)、dandelion(蒲公英)、roses(玫瑰)、sunflowers(向日葵)和tulips(郁金香)的类标分别为0、1、2、3和5。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_26.jpg?sign=1738836529-VhGAZ0d5XPDkgOtsBkojm2EpDS4oZx2k-0-f9bb8690e01b7b25972b8803f566b821)
图2-15 图片类标的输出结果
处理完类标之后,接下来需要对图片本身做一些处理,这里定义一个函数,用来加载和预处理图片数据。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_27.jpg?sign=1738836529-8JjgEXI8WiResuxCz3toGSyV4LtNsZnc-0-3e5a959fcbae5620ae97f8fc589135ae)
完成对类标和图片数据的预处理之后,使用“tf.data.Dataset”来构建和管理数据集:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_28.jpg?sign=1738836529-KExOdtDiQPUJMyXf5YXO955cqBB6TJCc-0-dd07192dc8a5d2b5715d37e5a43d1308)
输出结果如图2-16所示。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_29.jpg?sign=1738836529-h84YpYoSymJJBivVUWKOp6J4IVhfL8F2-0-3935982e9e300ad59a4cab1232bf451a)
图2-16 构建的数据集
在第35行和第41行代码中,“from_tensor_slices”方法使用张量的切片元素构建数据集,“tf.data.Dataset”类还提供了“from_tensor”,直接使用单个张量来构建数据集,以及“from_generator”方法使用生成器生成的元素来构建数据集。
在第39行代码中,我们使用了“tf.data.Dataset”的“map”方法,该方法允许自定义一个函数,该函数会将原数据集中的元素依次进行处理,并将处理后的数据作为新的数据集,处理前和处理后的数据顺序不变。例如这里我们自己定义了一个“load_and_preprocess_image”函数,将“path_ds”中的图片路径转换成了经过预处理的图像数据,并保存在了“image_ds”中。
最后使用“tf.data.Dataset”的“zip”方法将图片数据和类标数据压缩成“(图片,类标)”对。数据集中的部分数据可视化结果如图2-17所示。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_30.jpg?sign=1738836529-JkHLQoIeLd3FWWTLgvWtEeD8cBcF7nop-0-f297466e2cff2ceb88a5bd2d0b68e38b)
图2-17 数据集中部分数据的可视化
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_31.jpg?sign=1738836529-8Kc44B8ySoVjBA7ZEgV5r2pPyaasH33V-0-25f0afcfb1c841da4eaa23b37121a629)
接下来用创建的数据集训练一个分类模型,这个例子的目的是让读者了解如何使用我们创建的数据集,简单起见,直接使用“tf.keras.applications”包中训练好的模型,并将其迁移到我们的花朵分类任务上来。这里使用的是“MobileNetV2”模型。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_33.jpg?sign=1738836529-09Sot5rYswfXgvjwITHUJtKcijcpluWM-0-3981c9d4206aca40597118748287ce0c)
当我们执行第59行代码后,训练好的“MobileNetV2”模型会被下载到本地,该模型是在ImageNet数据集上训练的。因为我们想把该训练好的模型迁移到花朵分类问题中来,所以第61行代码将该模型的参数设置为不可训练和更新。
接下来打乱一下数据集,以及定义好训练过程中每个批次(Batch)数据的大小。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_34.jpg?sign=1738836529-kd92QiVqvr67xsPtTSY1VQ2tLMqy7ZW0-0-a2026f366145bd4a2ef1489fa5968a7d)
在第64行代码中,我们使用“tf.data.Dataset”类的“shuffle”方法将数据集进行打乱。第66行代码使用“repeat”方法让数据集可以重复获取,通常情况下,若一个训练回合(Epoch)只对完整的数据集训练一遍,则可以不需要设置“repeat”。“repeat”方法可以设置参数,例如“ds.repeat(2)”是让数据集可以重复获取两遍,即在一个训练回合中,可以使用两遍数据集。若不加参数的话,则默认可以无限次重复获取数据集。
第68、69行代码设置了训练过程中一个批次数据的大小。在第71行代码中,我们使用“tf.data.Dataset.prefetch”方法让ELT过程中的“数据准备和预处理(EL)”和“数据消耗(T)”过程并行。
由于“MobileNetV2”模型接收的输入数据是归一化在[-1,1]之间的数据,而在第31行代码中对数据进行了一次归一化处理后,其范围是[0,1],所以需要将数据映射到[-1,1]。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_35.jpg?sign=1738836529-Xx96mgp9r1hkG8B8RpIJi34NfyyDhB66-0-0a1f711599404eb89f8cd5de721dd982)
接下来定义模型,由于预训练好的“MobileNetV2”返回的数据维度为“(32,6,6,1280)”,其中“32”是一个批次(Batch)数据的大小,“6,6”代表输出的特征图的大小为6×6,“1280”代表该层使用了1280个卷积核。为了适应花朵分类任务,需要在“MobileNetV2”返回数据的基础上再增加两层网络层。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_36.jpg?sign=1738836529-DYdL3iL8bH0NF6UzueaG9FwanCEpZAON-0-a1f1942bc333d2238b7387b0a0aba6f3)
全局平均池化(Global Average Pooling,GAP)是对每一个特征图求平均值,将该平均值作为该特征图池化后的结果,因此经过该操作后数据的维度变为(32,1280)。由于花朵分类任务是一个5分类的任务,因此需要再使用一个全连接(Dense),将维度变为(32,5)。
接着我们编译一下模型,同时指定使用的优化器和损失函数:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_37.jpg?sign=1738836529-7jCJlQxzLJBRtKU32EZhzyhRLdHjcnm0-0-31619af70a6513a5cd367e47a94a2c1e)
“model.summary()”可以输出模型各层的参数概况,如图2-18所示。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_39.jpg?sign=1738836529-h2hyG8W4I5NaQcdVWRmKcBT2Zsn27Gtk-0-c5d6cc295a097ee11e32af2fcdb47c7c)
图2-18 模型各层的参数概况
最后使用“model.fit”训练模型:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_38.jpg?sign=1738836529-rh5WoNQMk6CQ2Rl5FdOHrgTfJu9vjcAq-0-4c6299531166ba03b42000380b330915)
这里参数“epochs”指定需要训练的回合数,“steps_per_epoch”代表每个回合要取多少个批次数据,通常“steps_per_epoch”的大小等于我们数据集的大小除以批次的大小后上取整。关于模型的训练部分,我们在2.3.2节中会详细介绍。
在本节中我们简单了解了“tf.data”API的使用,在后面章节的项目实战部分还会用到该API来构建数据输入管道,包括图片数据和文本数据等。
2.3.2 “tf.keras”API
Keras是一个基于Python编写的高层神经网络API,强调用户友好性、模块化及易扩展等,其后端可以采用TensorFlow、Theano及CNTK,目前大多是以TensorFlow作为后端引擎的。考虑到Keras优秀的特性及它的受欢迎程度,TensorFlow将Keras的代码吸收进来,并将其作为高级API提供给用户使用。“tf.keras”不强调原来Keras的后端可互换性,而是在符合Keras标准的基础上让其与TensorFlow结合得更紧密(例如支持TensorFlow的Eager Execution模式,支持“tf.data”,以及支持TPU训练等)。“tf.keras”提高了TensorFlow的易用性,同时也保持了TensorFlow的灵活性和性能。
1.基本模型的搭建和训练
可以使用“tf.keras.Sequential”来创建基本的网络模型。通过这种方式创建的模型又称为顺序模型,因为这种模型是由多个网络层线性堆叠而成的。
首先,导入需要的包:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_40.jpg?sign=1738836529-wuTI8sY3oU748T1L6I4qrPmDointILeM-0-a4a8c3aa703b910327645ab81bcc922c)
然后,创建一个顺序模型:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_41.jpg?sign=1738836529-tRpLG5imaaMVofj9oEF8QfFI1yHxujR9-0-9a001107b033832bc5fe2dc5532e00ae)
上面的代码中,在定义这个顺序模型的同时添加了相应的网络层,除此之外也可以使用“add”方法逐层添加:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_43.jpg?sign=1738836529-Mu5UbFHD2eTAlndYYfyjE7hhaliwelAY-0-5bc5a866e56c05eb8816f8f286743054)
“tf.keras.layers”用于生成网络层,包括全连接层(tf.keras.layers.Dense())、Dropout层(tf.keras.layers.Dropout),以及卷积网络层(如二维卷积:tf.keras.layers.Conv2D)等。创建好网络结构后,要对网络进行编译:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_44.jpg?sign=1738836529-KyJXMelIMx7ERZvObCSAi3KEaapBVXIQ-0-bd65878c91d13d74d15224508e24177b)
在编译模型的时候需要设置一些必需参数,例如“optimizers”用来指定我们想使用的优化器及设定优化器的学习率,如Adam优化器“tf.keras.optimizer.Adam”、SGD优化器“tf.keras.optimizer.SGD”等,在第15行代码中使用的是Adam优化器,并设置学习率为“0.001”。
“loss”参数用来设置模型的损失函数(又称目标函数),例如均方误差损失函数(mean_squared_error)、对数损失函数(binary_ crossentropy),以及多分类的对数损失函数(categorical_crossentropy),等等。
“metrics”用来设定模型的评价函数,模型的评价函数与损失函数相似,不过评价函数只用来显示给用户查看,并不用于模型的训练。除了自带的一些评价函数外,这里还可以使用自定义评价函数。
编译好模型之后就可以开始训练了,这里使用NumPy生成一组随机数作为训练数据:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_45.jpg?sign=1738836529-Q1CzIiRLOGzdVknNId03Zv7Ede40dC0l-0-a8d71d0993cf9f9342bbb96fa018fc64)
第20行和第21行代码随机生成样本数据和类标。第25行代码使用“model.fit”来执行模型的训练,其中参数“data”和“labels”分别为训练数据和类标,“epochs”为训练的回合数(一个回合即在全量数据集上训练一次),“batch_size”为训练过程中每一个批次数据的大小。输出结果如图2-19所示。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_46.jpg?sign=1738836529-IB05QHdEE5DFSNVt0PVGWdiLGPHM6iDF-0-b3074287b19f57ba6dae3630a0ae7593)
图2-19 输出结果
在训练模型的工程中,为了更好地调节参数,方便模型的选择和优化,通常会准备一个验证集。这里随机生成一个验证集:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_47.jpg?sign=1738836529-NoCcq9D1JznGAFi3HjsuSs2CnJj7pJWA-0-3cb6b2484b41dc0da0967573926be14f)
输出结果如图2-20所示。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_48.jpg?sign=1738836529-vxR0tppGJW9Y9MNKv2feI2TsGmMelkIW-0-5a0d8186cf985d036a724f8604281a0c)
图2-20 增加验证集后的输出结果
和图2-19相比,这里多了“val_loss”和“val_accuracy”,分别为验证集上的损失和准确率。
在上面的例子中,我们直接在NumPy数据上训练模型,也可以使用“tf.data”将其转为数据集后再传递给模型去训练:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_49.jpg?sign=1738836529-co9CVY0aNZN89lcIFvZJ4bAUxGiJeksw-0-64732b33b2ac7ec520c0b668773d7841)
模型训练好之后,我们希望用验证集去对模型进行评估,这里可以使用“model.evaluate”对模型进行评估:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_50.jpg?sign=1738836529-nGGAIy8EMMuNTtrqAuD24GIhs8Glnmmz-0-1cfd768a7490fc84edd8ddda5ceb21a8)
结果如图2-21所示。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_51.jpg?sign=1738836529-e6QD4H2y0dTLsN9Vc6nDEEnG4tbADdRt-0-7e4d44c6ad86af68bc0da7658af483fc)
图2-21 模型评估结果
最后,使用“model.predict”对新的数据进行预测:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_52.jpg?sign=1738836529-xFXrEQaJYVKmGuaX9ZK3JpxWjFZR78R3-0-2a0dedff8315edef668d92240c21a637)
结果如图2-22所示。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_53.jpg?sign=1738836529-j1BFIJKofglCZTv4qyDbHquAdIybtQ5p-0-27ce00cc6297b3184627a30b3fe58122)
图2-22 使用训练好的模型预测新的数据
2.搭建高级模型
(1)函数式API
可以使用“tf.keras.Sequential”来搭建基本的网络结构,但更多的时候我们面临的是比较复杂的网络结构,例如,模型可能有多输入或多输出、模型中的某些网络层需要共享等,此时就需要用到函数式API。
实现一个简单的例子:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_54.jpg?sign=1738836529-HTHjIiDDMFvOkwjAGGNlTG5YKYSBERsJ-0-77aee75fc3f454f2142db2ecac09abb4)
接下来使用上面定义的网络层来创建模型:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_55.jpg?sign=1738836529-2UTCQDWEh9xOXW5VhPNZoHpFBUvbseDB-0-7682e6e1c6562334b166a27e8b45f44a)
(2)实现自定义的模型类和网络层
通过继承“tf.keras.Model”和“tf.keras.layers.Layer”可以实现自定义的模型类和网络层为我们构建自己的网络结构提供了非常好的灵活性。例如定义一个简单的前馈神经网络模型:,
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_56.jpg?sign=1738836529-z6kJx85Txm1tTCnKKLE0n3VZKgU9b2If-0-d04b189f518b3ff9950a713c11eb3160)
我们需要在“__init__”方法中定义好模型中所有的网络层,并作为模型类的属性。在“call”方法中可以定义模型的正向传递过程。之后就可以调用这个模型。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_58.jpg?sign=1738836529-CnORioC4qR6yH0JfYsLMnZNcw666Qioh-0-9dd22572c62325cfe514d5743f74869f)
以上是我们自定义一个简单的网络模型的例子,通过继承“tf.keras.layers.Layer”类还可以实现自定义的网络层。
3.回调函数
回调函数会在模型的训练阶段被执行,可以用来自定义模型训练期间的一些行为,例如输出模型内部的状态等。我们可以自己编写回调函数,也可以使用内置的一些函数,例如:
●tf.keras.callbacks.ModelCheckpoint:定期保存模型。
●tf.keras.callbacks.LearningRateScheduler:动态地改变学习率。
●tf.keras.callbacks.EarlyStopping:当模型在验证集上的性能不再提升时终止训练。
●tf.keras.callbacks.TensorBoard:使用TensorBoard来监测模型。
回调函数的使用方式如下:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_59.jpg?sign=1738836529-BnQb6mtSXRKoxQzPZrxXiYyfri9j0ien-0-469187c25ca8ceaf662a14df4b3cb78f)
4.模型的保存和恢复
使用“model.save()”和“tf.keras.models.load_model()”来保存和加载由“tf.keras”训练的模型:
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_60.jpg?sign=1738836529-UvkAGesvbVzab74oT5XBLCNS5W50Btri-0-52b7e5e382df45bf60702fae751cdaa0)
通过“model.save()”保存的是一个完整的模型信息,包括模型的权重和结构等。除保存完整的模型外,还可以单独保存模型的权重参数或者模型的结构。
![](https://epubservercos.yuewen.com/C60F01/16896237405619506/epubprivate/OEBPS/Images/txt002_61.jpg?sign=1738836529-lnL0ToWVmXLkHO6r4K853apf7YRC5FMI-0-f7243f8126fb000713d0fbe7ee042e6f)