Python+Tensorflow机器学习实战
上QQ阅读APP看书,第一时间看更新

3.2 存储和加载模型

机器学习中最关键的就是模型的设计与训练。

在模型的设计和训练过程中,会耗费大量的时间。为了降低训练过程中因意外情况发生而造成的不良影响,我们会对训练过程中的模型进行定期存储。

为了保证意外中断的模型能够继续训练以及训练完成的模型在其他数据上能够直接使用,我们会对存储的模型进行加载。

TensorFlow中提供了tf.train.Saver类来实现训练模型的保存和加载。tf.train.Saver类的save()方法将TensorFlow模型保存到指定路径中,该类的restore()方法用来加载这个已保存的TensorFlow模型。

3.2.1 存储模型

TensorFlow模型包括计算图以及计算过程中的值,主要包括计算图中的所有变量、操作等,以及计算过程中的权重、偏差、梯度等值的更新结果。

在TensorFlow中,提供了tf.train.Saver类来完成模型的存储。

  
     saver = tf.train.Saver(max_to_keep, keep_checkpoint_every_n_hours)

在创建类时,可以指定最多可保留的模型数、训练过程每隔多长时间进行一次自动保存等。

tf.train.Saver类提供了save()方法来实现保存工作。在该方法中需要说明会话、所保存模型的名称,以及每次保存模型时间隔的迭代学习次数等。需要注意的是,一旦调用该方法,其后定义的变量将不会被保存。具体实现如下:

  
     saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)

接下来,使用saver类保存训练模型,该模型实现了5*(w1+w2)的计算,具体实现如下:

上述代码获取值1和2并分别填充到w1和w2中,然后计算其和的5倍。运行该代码,打印输出如下:

  
     15.0
     --------------

可以看到,不仅有控制台的打印输出,还在代码所在文件夹中出现了四个文件,如图3.5所示。

图3.5 存储的文件

这四个文件分别存储训练过程中的不同信息,其中:

.meta文件保存TensorFlow计算图的结构信息。

data和.index文件存储训练好的参数。

3.2.2 加载模型

加载存储好的模型,包括加载模型和加载训练参数两步。

TensorFlow提供了tf.train.import()相关方法来加载已存储的模型,如import_meta_graph()方法:

  
     saver = tf.train.import_meta_graph('my_test_model-1000.meta')

其中,“my_test_model-1000.meta”文件就是已存储的模型文件。

完成计算图的加载后,还需要加载训练过的参数的值,这需要使用restore()方法。

完成模型的加载后,可以将待训练数据放入模型中进行训练。例如,加载刚才存储的模型5*(w1+w2),将待训练数据13、17分别放入模型的w1和w2中进行计算。具体实现如下:

运行上述代码,打印输出(13+17)*5的最终结果,如图3.6所示。

图3.6 打印结果

可以很明显地看到,所存储的模型已加载,并使用新的数据w1=13和w2=17,完成5*(w1+w2)的计算。