比方说我有
sequence = np.array([[1],[2],[3],[4],[5]])
我将生成器定义为
def generator():
for el in sequence:
yield el
现在,我希望使用Tensorflow中定义的from_generator()从生成器中检索数据。
dataset = tf.data.Dataset().from_generator(generator,
output_types= tf.int64,
output_shapes=(tf.TensorShape([1])))
iterator = dataset.make_initializable_iterator()
el = iterator.get_next()
为了检索我用过的,
with tf.Session() as sess:
sess.run(iterator.initializer)
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
有没有一种方法可以使用循环来获取'el‘,而不是每次都执行sess.run(el)?
发布于 2019-08-07 19:30:00
这应该会达到您想要的效果:
with tf.Session() as sess:
sess.run(iterator.initializer)
try:
while True:
print(sess.run(el))
except tf.errors.OutOfRangeError:
print("Iterating finished")
pass
https://stackoverflow.com/questions/57401089
复制