我想用 a 来管理我的培训,tf.estimator.Estimator
但在与tf.data
API 一起使用时遇到了一些麻烦。
我有这样的事情:
def model_fn(features, labels, params, mode):
# Defines model's ops.
# Initializes with tf.train.Scaffold.
# Returns an tf.estimator.EstimatorSpec.
def input_fn():
dataset = tf.data.TextLineDataset("test.txt")
# map, shuffle, padded_batch, etc.
iterator = dataset.make_initializable_iterator()
return iterator.get_next()
estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
由于我不能将 amake_one_shot_iterator
用于我的用例,我的问题是它input_fn
包含一个应该在其中初始化的迭代器model_fn
(这里,我tf.train.Scaffold
用来初始化本地操作)。
另外,我知道我们不能只使用input_fn = iterator.get_next
其他操作,否则其他操作将不会添加到同一个图中。
初始化迭代器的推荐方法是什么?