首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用tf.trainable_variables()保存convert_variables_to_constants

使用tf.trainable_variables()保存convert_variables_to_constants
EN

Stack Overflow用户
提问于 2017-08-20 06:49:08
回答 1查看 2.1K关注 0票数 15

我有一个Keras模型,我想将它转换为Tensorflow protobuf (例如saved_model.pb)。

该模型来自vgg-19网络上的转移学习,在vgg-19网络中,将头和头切断,并使用完全连接的+softmax层进行训练,而其余的vgg-19网络则被冻结。

我可以在Keras中加载模型,然后使用keras.backend.get_session()在tensorflow中运行模型,生成正确的预测:

代码语言:javascript
复制
frame = preprocess(cv2.imread("path/to/img.jpg")
keras_model = keras.models.load_model("path/to/keras/model.h5")

keras_prediction = keras_model.predict(frame)

print(keras_prediction)

with keras.backend.get_session() as sess:

    tvars = tf.trainable_variables()

    output = sess.graph.get_tensor_by_name('Softmax:0')
    input_tensor = sess.graph.get_tensor_by_name('input_1:0')

    tf_prediction = sess.run(output, {input_tensor: frame})
    print(tf_prediction) # this matches keras_prediction exactly

如果我不包括行tvars = tf.trainable_variables(),那么tf_prediction变量是完全错误的,与来自keras_prediction的输出完全不匹配。实际上,输出中的所有值(具有4个概率值的单个数组)都是完全相同的(~0.25,都添加到1)。这让我怀疑,如果没有首先调用tf.trainable_variables(),那么头的权重就会初始化为0,这是在检查模型变量之后确认的。在任何情况下,调用tf.trainable_variables()都会导致tensorflow预测正确。

问题是,当我试图保存这个模型时,tf.trainable_variables()中的变量实际上并没有保存到.pb文件中:

代码语言:javascript
复制
with keras.backend.get_session() as sess:
    tvars = tf.trainable_variables()

    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ['Softmax'])
    graph_io.write_graph(constant_graph, './', 'saved_model.pb', as_text=False)

我要问的是,如何才能将Keras模型保存为tf.training_variables()完整的Tensorflow原型呢?

非常感谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-08-28 18:53:15

因此,冻结图形中的变量(转换为常量)的方法应该有效,但这不是必要的,而且比其他方法更棘手。(下文对此作更多介绍)。如果您希望图形因某种原因而冻结(例如导出到移动设备),我需要更多的细节来帮助调试,因为我不知道Keras在幕后用图形做什么。但是,如果您希望稍后只保存和加载一个图,我可以解释如何做到这一点(尽管不能保证Keras所做的任何操作都不会搞砸它.,我很乐意帮助调试它)。

实际上这里有两种格式。一个是GraphDef,它用于检查点,因为它不包含关于输入和输出的元数据。另一个是包含元数据和图def的MetaGraphDef,元数据用于预测和运行ModelServer (来自tensorflow/serving)。

在这两种情况下,您需要做的不仅仅是调用graph_io.write_graph,因为变量通常存储在图的外部。

这两种用例都有包装器库。tf.train.Saver主要用于保存和恢复检查点。

但是,由于您想要预测,我建议使用tf.saved_model.builder.SavedModelBuilder来构建SavedModel二进制文件。我已经为此提供了一些锅炉板:

代码语言:javascript
复制
from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY as DEFAULT_SIG_DEF
builder = tf.saved_model.builder.SavedModelBuilder('./mymodel')
with keras.backend.get_session() as sess:
  output = sess.graph.get_tensor_by_name('Softmax:0')
  input_tensor = sess.graph.get_tensor_by_name('input_1:0')
  sig_def = tf.saved_model.signature_def_utils.predict_signature_def(
    {'input': input_tensor},
    {'output': output}
  )
  builder.add_meta_graph_and_variables(
      sess, tf.saved_model.tag_constants.SERVING,
      signature_def_map={
        DEFAULT_SIG_DEF: sig_def
      }
  )
builder.save()

运行此代码后,您应该有一个mymodel/saved_model.pb文件和一个目录mymodel/variables/,其中包含与变量值相对应的protobufs。

然后再次加载模型,只需使用tf.saved_model.loader

代码语言:javascript
复制
# Does Keras give you the ability to start with a fresh graph?
# If not you'll need to do this in a separate program to avoid
# conflicts with the old default graph
with tf.Session(graph=tf.Graph()):
  meta_graph_def = tf.saved_model.loader.load(
      sess, 
      tf.saved_model.tag_constants.SERVING,
      './mymodel'
  )
  # From this point variables and graph structure are restored

  sig_def = meta_graph_def.signature_def[DEFAULT_SIG_DEF]
  print(sess.run(sig_def.outputs['output'], feed_dict={sig_def.inputs['input']: frame}))

显然,通过tensorflow/serving或Cloud引擎可以更有效地预测此代码,但这应该有效。可能Keras正在做一些会干扰这个过程的事情,如果是的话,我们希望了解它(并且我希望确保Keras用户也能够冻结图形,所以如果您想给我发送一个包含完整代码的gist或者其他什么东西,我可能会找到一个熟悉Keras的人来帮助我调试)。

编辑:您可以在这里找到一个端到端的示例:https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/census/keras/trainer/model.py#L85

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45779268

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档