
I'm adapting code I wrote for TF 1.2 that uses the Estimator API with TFRecords. However, the input_fn always returns just the first batch, so training never finishes (or progresses).

import tensorflow as tf

def gen_input(filename):
  def decode(line):
    features = {
      'x': tf.FixedLenFeature((3,), tf.float32),
      'y': tf.FixedLenFeature((), tf.int64)
    }
    parsed = tf.parse_single_example(line, features)
    return parsed['x'], parsed['y']

  def input_fn():
    dataset = (tf.data.TFRecordDataset([filename])).map(decode)
    dataset = dataset.repeat(1)
    dataset = dataset.batch(2)

    # A new dataset and one-shot iterator are built each time input_fn is called.
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

  return input_fn


# model_dir and model_fn are defined elsewhere in my code.
estimator = tf.estimator.Estimator(
    model_dir=model_dir,
    model_fn=model_fn,
    params={})

train_input_fn = gen_input('train.tfrecord')
eval_input_fn = gen_input('eval.tfrecord')

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
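gen_input is a factory: it closes over filename and returns an input_fn that builds a brand-new dataset and one-shot iterator every time it is called. The plain-Python analog below is only an illustration of that closure pattern, not TensorFlow code; gen_input_py and the in-memory records list are hypothetical stand-ins for gen_input and the TFRecord file. It shows that batches advance within one iterator, while a second call to input_fn starts over at the first record.

```python
from typing import Callable, Iterator, List, Sequence, Tuple

# Hypothetical record type: (x features, y label), mirroring decode().
Record = Tuple[List[float], int]


def gen_input_py(records: Sequence[Record],
                 batch_size: int = 2) -> Callable[[], Iterator[List[Record]]]:
    """Hypothetical pure-Python analog of gen_input: returns a closure."""
    def input_fn() -> Iterator[List[Record]]:
        # Each call starts a fresh pass over the records, analogous to
        # building a new TFRecordDataset and one-shot iterator inside input_fn.
        def batches() -> Iterator[List[Record]]:
            for start in range(0, len(records), batch_size):
                yield list(records[start:start + batch_size])
        return batches()
    return input_fn


# Hypothetical stand-in for the 5-record TFRecord file.
records = [([float(i)] * 3, i) for i in range(5)]
input_fn = gen_input_py(records)

first_call = input_fn()
second_call = input_fn()
print(next(first_call)[0][1])   # label of the first record: 0
print(next(first_call)[0][1])   # same iterator advances: 2
print(next(second_call)[0][1])  # a new call starts over: 0
```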

My TFRecord file contains only 5 examples, so with dataset.repeat(1) and dataset.batch(2) the model should finish after 3 steps (batches of 2, 2 and 1). Furthermore, I log the features passed to my model_fn at each step to TensorBoard, and the logged values are always the same.
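The expected step count can be checked with a little arithmetic. The helper below is a hypothetical sketch, not part of tf.data; it assumes the default tf.data behaviour of keeping the final partial batch, and that repeat() is applied before batch(), so batches may span epoch boundaries.

```python
import math


def batch_sizes(num_records: int, batch_size: int, num_epochs: int = 1) -> list:
    """Hypothetical helper: batch sizes produced by repeat(n).batch(k).

    Assumes the final partial batch is kept (no remainder dropping) and
    that repeat() comes before batch(), so epochs are concatenated first.
    """
    total = num_records * num_epochs
    sizes = [batch_size] * (total // batch_size)
    if total % batch_size:
        sizes.append(total % batch_size)
    return sizes


sizes = batch_sizes(num_records=5, batch_size=2)
print(sizes)              # [2, 2, 1]
print(len(sizes))         # 3 training steps
print(math.ceil(5 / 2))   # 3, the same count via ceil(records / batch_size)
```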

What am I doing wrong here?
