我最近开始使用tensorflow,更具体地说,是使用新的dataset API。通过将数据集的迭代器插入到表示输入和标签的图的节点,我已经成功地使用数据集将训练数据提供给我的简单模型。类似于:
input = input_dataset.make_one_shot_iterator().get_next()
label = label_dataset.make_one_shot_iterator().get_next()
现在我在想,当我必须对用户输入进行推断时,该怎么办,即,用户给了我一个输入值,而我必须做出预测。如果我有一个占位符,我只会将用户输入放在一个feed_dict中,但是对于dataset api,我几乎不知道如何做类似的事情。我是否应该有一个单独的图形,仅用于推断,其中我的input
变量是占位符?
我已经尝试做了一个here描述的可馈送迭代器,但这只适用于字符串的占位符,而我的输入是int32。
谢谢你的建议。
发布于 2018-06-06 18:50:06
为此,tensorflow提供了tf.placeholder_with_default
应用程序接口
# Create a Dataset
dataset = tf.data.Dataset.zip((input_dataset, label_dataset)).batch(32).repeat(...)
# Create Iterator
input, label = dataset.make_one_shot_iterator()
# Create Placholders
x = tf.placeholder_with_default(input, shape=[...], name='input')
y = tf.placeholder_with_default(label, shape-[...], name='label')
def nn_model(features, labels):
logits = ...
loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))
optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)
return optimizer, loss
# Create Model
train_op, loss_op = nn_model(x, y)
# Training
sess.run(train_op)
# Inference
sess.run(logits, feed_dict={x:..., y:...})
https://stackoverflow.com/questions/46804666
复制相似问题