我正在构建一个 TensorFlow Estimator,我想使用该tf.estimator.train_and_evaluate()
函数进行训练和评估。此功能的文档提供以下建议:
还建议在执行评估之前对模型进行更长时间的训练,比如多个 epoch,因为每次训练的输入管道都是从头开始的。
这是有道理的,因为train_and_evaluate()
通过在调用estimator.train()
和之间交替工作estimator.evaluate()
,为每个新调用拆除计算图。就我而言,这是一个问题,因为我想相对频繁地评估模型,而且我input_fn
的设置似乎有很多开销。它目前看起来像这样:
def input_fn():
# Build dataset from generator
dataset = tf.data.Dataset.from_generator(
generator=instance_generator,
output_types=types,
output_shapes=shapes,
)
dataset = dataset.shuffle(buffer_size=dataset_size)
dataset = dataset.repeat(epochs_per_eval)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(1)
return dataset
我怀疑这个函数的很多时间成本来自于洗牌,因为它需要首先生成整个数据集。洗牌可能并不慢,但我instance_generator
的是。理想情况下,我想找到一种方法来避免必须为每个 train/eval 调用从生成器重建数据集。有什么方法可以使用 Dataset 类来实现这一点?有没有一种方法可以在数据集生成后缓存它的状态,以便input_fn
在第一次调用之后的每个新调用都变得更便宜?