训练好一个神经网络模型后,我们就希望能够应用在预测数据上。那么,如何把模型存储起来呢?同时,对于一个已经存储起来的模型,在将其应用在预测数据上时又如何加载呢?
Tensorflow的API提供了以下两种方式来存储和加载模型。
(1)生成检查点文件(checkpoint file),扩展名一般为.ckpt,通过tf.train.Saver对象上调用Saver.save()生成。它包含权重和其他在程序中定义的变量,不包含图结构。如果需要在另一个程序中使用,需要重新创建图形结构,并告诉Tensorflow如何处理这些权重。
下面就分“模型存储”和“图存储”来介绍这两种方式。在Tensorflow的高级API,如Keras中,也提供了更高级的语句来保存和加载模型。
模型的存储与加载
模型存储主要是建立一个tf.train.Saver()来保存变量,并且指定保存的位置,一般模型的扩展名为.ckpt。
下面我们定义一个新的神经网络,含两个全连接层和一个输出层,来训练MNIST数据集,并把训练好的模型存储起来。我们用MNIST数据集说明。
1.加载数据及定义模型
加载数据及定义模型的代码如下:
生成网络模型,得到预测值,代码如下:
定义损失函数,代码如下:
接下来训练刚才定义的模型,并把每一轮训练得到的参数都存储下来。
2.训练模型及存储模型
首先,我们定义一个存储路径,这里就用当前路径下的ckpt_dir目录,代码如下:
定义一个计数器,为训练轮数计数,代码如下:
当定义完所有变量后,调用tf.train.Saver()来保存和提取变量,其后面定义的变量将不会被存储,代码如下:
训练模型并存储,如下:
于是,在训练过程中,ckpt_dir下会出现16个文件,其中有5个model.ckpt-.data-00000-of-00001文件,是训练过程中保存的模型,5个model.ckpt-.meta文件,是训练过程中保存的元数据(Tensorflow默认只保存最近5个模型和元数据,删除前面没用的模型和元数据),5个model.ckpt-.index文件,代表迭代次数,以及一个检查点文本文件,里面保存着当前模型和最近的5个模型,内容如下:
model_checkpoint_path:"model.ckpe-60"
all_model_checkpoint_paths:"model.ckpt-56"
all_model_checkpoint_paths:"model.ckpt-57"
all_model_checkpoint_paths:"model.ckpt-58"
all_model_checkpoint_paths:"model.ckpt-59"
all_model_checkpoint_paths:"model.ckpt-60"
那么,假如在训练某个模型时突然因为某种原因,脚本停止运行了,或者机器重启了,是不是就要从头开始训练呢?我们知道,训练一个神经网络的时间都比较长,少则几个小时,多则几天,甚至几周。如果能将之前训练的参数保存下来,就可以在出现意外状况时接着上一次的地方开始训练。此外,每个固定的轮数在检查点保存一个模型(.ckpt文件),也有利于随时将模型拿出来预测,用前几次的预测效果就可以估计出神经网络究竟设计得怎么样。
3.加载模型
如果已有训练好的模型变量文件,可以用saver.restore来进行模型加载:
图的存储与加载
当仅保存图模型时,才将图写入二进制协议文件中,例如:
当读取时,又从协议文件中读取出来:
好看请点这里~
领取专属 10元无门槛券
私享最新 技术干货