I'm adapting code that I used in TF 1.2, using the Estimator API with TFRecords. However, the input_fn
always returns just the first batch, so training never makes progress and never finishes.
def gen_input(filename, batch_size=2, num_epochs=1):
    """Build an Estimator ``input_fn`` that reads examples from a TFRecord file.

    Args:
        filename: Path of the TFRecord file to read.
        batch_size: Number of examples grouped into each batch. Defaults to
            2, the value that used to be hard-coded.
        num_epochs: Count handed to ``Dataset.repeat``. Defaults to 1, the
            value that used to be hard-coded.

    Returns:
        A zero-argument callable. When invoked it builds the dataset
        pipeline and returns the ``(features, labels)`` tensors produced by
        a one-shot iterator over that pipeline.
    """

    def decode(serialized_example):
        """Parse one serialized example into an ``(x, y)`` pair."""
        # 'x' is a length-3 float32 vector; 'y' is a scalar int64.
        feature_spec = {
            'x': tf.FixedLenFeature((3,), tf.float32),
            'y': tf.FixedLenFeature((), tf.int64),
        }
        parsed = tf.parse_single_example(serialized_example, feature_spec)
        return parsed['x'], parsed['y']

    def input_fn():
        """Create the dataset pipeline and return its next-element tensors."""
        dataset = tf.data.TFRecordDataset([filename]).map(decode)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)
        # The iterator is created inside input_fn so that every call builds
        # a fresh pipeline rather than sharing one across calls.
        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()
        return features, labels

    return input_fn
# Input pipelines: one TFRecord file for training, one for evaluation.
train_input_fn = gen_input('train.tfrecord')
eval_input_fn = gen_input('eval.tfrecord')

# NOTE(review): TrainSpec is given no max_steps here -- confirm how the
# training loop is expected to terminate.
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

# model_fn and model_dir are defined outside this snippet.
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    params={},
)

# Run training and evaluation with the specs built above.
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
My TFRecord file contains only 5 items, so using dataset.batch(2) together with dataset.repeat(1)
should make the model finish after 3 steps. Furthermore, I'm logging the features
that are passed to my model_fn at each step to TensorBoard, and the logged values are always the same.