tf.contrib.data.Iterator
如果tf.estimator.Estimator
也被使用,应该如何初始化?
问题之一是输入图(tf 图处理输入的部分)应该定义在intput_fn()
- 因为 tf.estimator 创建单独的图。
这个要求使得很难访问迭代器init ops
并传递它们to tf.estimator
(传递操作可以train/evaluate/predict
在以钩子形式调用时完成)。
tf.contrib.data.Iterator
如果tf.estimator.Estimator
也被使用,应该如何初始化?
问题之一是输入图(tf 图处理输入的部分)应该定义在intput_fn()
- 因为 tf.estimator 创建单独的图。
这个要求使得很难访问迭代器init ops
并传递它们to tf.estimator
(传递操作可以train/evaluate/predict
在以钩子形式调用时完成)。
使用SessionManager
as hook 可以解决同样的问题。
sm = tf.train.SessionManager(local_init_op=iterator_init_op)
...
estimator = tf.train.Estimator(...)
estimator.train(input_fn, hooks=[sm], steps=None, max_steps=None)
一种选择是将您包装input_fn
在另一个设置简单 SessionRunHook 的函数中init_hook
。所有操作都在内部定义input_fn
,它在与模型其余部分相同的图中被调用,但您可以从中将 设置iterator_init_op
为属性init_hook
。
def get_input_fn(mode="train"):
init_hook = IteratorInitHook()
def input_fn():
...
iterator = dataset.make_initializable_iterator()
init_hook.iterator_init_op = iterator.initializer
return input_fn, init_hook
class IteratorInitHook(tf.train.SessionRunHook):
def after_create_session(self, session, coord):
session.run(self.iterator_init_op)
现在,在构建 时Experiment
,您可以获得这些输入函数和 init 钩子,它们在创建训练/评估会话时被调用。它应该与estimator.train
.
train_input_fn, train_init_hook = get_input_fn("train")
test_input_fn, test_init_hook = get_input_fn("test")
return tf.contrib.learn.Experiment(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=test_input_fn,
train_monitors=[train_init_hook],
eval_hooks=[test_init_hook],
)