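I want to use the feedable iterator design from the TensorFlow Dataset API so that I can switch to validation data after some training steps. But when I switch to the validation data, the whole session ends.
The following code demonstrates what I am trying to do: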
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    training_ds = tf.data.Dataset.range(32).batch(4)
    validation_ds = tf.data.Dataset.range(8).batch(4)

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_ds.output_types, training_ds.output_shapes)
    next_element = iterator.get_next()

    training_iterator = training_ds.make_initializable_iterator()
    validation_iterator = validation_ds.make_initializable_iterator()

with graph.as_default():
    with tf.train.MonitoredTrainingSession() as sess:
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        sess.run(training_iterator.initializer)
        count_training = 0
        while not sess.should_stop():
            x = sess.run(next_element, feed_dict={handle: training_handle})
            count_training += 1
            print('{} [training] {}'.format(count_training, x.shape))
            # print(x)

            # we do periodic validation
            if count_training % 4 == 0:
                sess.run(validation_iterator.initializer)
                count_validation = 0
                while not sess.should_stop():
                    y = sess.run(next_element, feed_dict={handle: validation_handle})
                    count_validation += 1
                    print(' {} [validation] {}'.format(count_validation, y.shape))
                    # print(y)
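The training data has 32 elements in batches of 4, which gives 8 batches. The validation data has 8 elements in batches of 4, which gives 2 batches. Validation runs every 4 training steps, so I expect: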
# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
# 1 [validation]
# 2 [validation]
# 5 [training]
# 6 [training]
# 7 [training]
# 8 [training]
# 1 [validation]
# 2 [validation]
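To make the arithmetic behind that expectation explicit, here is a minimal pure-Python sketch (no TensorFlow involved) that reproduces the schedule. The function name `expected_schedule` and its parameters are hypothetical, introduced only for illustration; the numbers are the ones from the code above.

```python
def expected_schedule(n_train=32, n_valid=8, batch_size=4, validate_every=4):
    """Return the log lines expected if validation runs to completion
    after every `validate_every` training steps (hypothetical helper)."""
    train_batches = n_train // batch_size   # 32 // 4 = 8 training batches
    valid_batches = n_valid // batch_size   # 8 // 4 = 2 validation batches
    lines = []
    for step in range(1, train_batches + 1):
        lines.append('{} [training]'.format(step))
        # Periodic validation: a full pass over the validation batches.
        if step % validate_every == 0:
            for v in range(1, valid_batches + 1):
                lines.append(' {} [validation]'.format(v))
    return lines


print('\n'.join(expected_schedule()))
```

This yields 8 training lines with a 2-line validation pass after steps 4 and 8, which is the listing above.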
But it stops as soon as the first validation pass is done:
# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
# 1 [validation]
# 2 [validation]
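My guess, which I have not confirmed, is that the third validation `sess.run` hits the end of the validation dataset and raises an out-of-range error, and that the monitored session treats this as a normal "end of input" signal and exits quietly. The toy model below only mimics that assumed behaviour in pure Python; `ToySession` and `run_toy` are hypothetical names, `StopIteration` stands in for the out-of-range error, and none of this is TensorFlow code.

```python
class ToySession:
    """Hypothetical stand-in for a monitored session (not TensorFlow)."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Assumption: an end-of-data error is treated as a normal stop
        # signal. Returning True suppresses the exception.
        return exc_type is not None and issubclass(exc_type, StopIteration)


def run_toy():
    log = []
    training = iter(range(8))  # 8 training batches
    with ToySession():
        step = 0
        while True:
            next(training)
            step += 1
            log.append('{} [training]'.format(step))
            if step % 4 == 0:
                validation = iter(range(2))  # 2 validation batches
                count = 0
                while True:
                    # The third call raises StopIteration, which escapes
                    # both loops and is swallowed by ToySession.__exit__.
                    next(validation)
                    count += 1
                    log.append(' {} [validation]'.format(count))
    return log


print('\n'.join(run_toy()))
```

In this toy version the log ends after the first validation pass with no error shown, which matches what I observe.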
So, how can I use this feedable iterator with tf.train.MonitoredTrainingSession?