星标或者置顶【OpenCV学堂】
干货教程第一时间送达!
tf.train.Saver API说明
保存于恢复变量,对定义好完成训练或者完成部分训练的计算图所有OP操作的中间变量进行保存,保存为检查点文件(checkpoint file),检查点文件通过restore方法完成恢复,实现从变量到张量值(tensor value)得映射加载,可以进行调用或者继续训练。同时Saver支持全局步长参数,通过对不同的step自动保存为检查点
上述代码表示分别在step=0与step=1000的时候保存检查点。
Saver在保存检查点的时候默认保存计算图的全部变量,但是可以通过var_list来决定保存多少个变量到检查点文件中去。对保存的检查点进行恢复可以调用如下的方法:
从检查点恢复变量并映射到相关的tensor中去,要求必须有一个当前会话才可以重新加载计算图。当使用这种方式时候就无需再重复调用初始化方法来初始化变量了,restore方法本身就完成了变量初始化,然后就可以继续训练或者使用计算图进行预测。
预测图导出
使用tf.train.Saver会保存检测点文件,但是这些文件不是一个,是四个文件一组:
其中
prefix是前缀名称
steps是运行number of steps
当prefix=my_cnn_mnist,steps=10000时
通过读取checkpint文件与meta文件加载计算图,然后把所有的变量转换为常量形式通过GFile进行串行化写入生成预测图(PB文件),从检查点导出成为预测图(PB文件)的代码如下:
这段代码我也是借鉴tensorflow中一个工具类copy过来的,发现很好用!
一个例子
首先定义个网络模型,对于输入与预测部分tensor的name属性我们都给予赋值。
定义输入-X
定义预测输出
构建卷积神经网络的代码如下
保存检查点的代码如下:
导出预测图之后使用预测实现手写数字预测的代码如下
运行结果:
天下难事,必作于易
天下大事,必作于细
欢迎扫码加入【OpenCV研习社】
领取专属 10元无门槛券
私享最新 技术干货