我想加快使用 Estimator API 和 input_fn 编写的训练例程tf.data.Dataset
。
我的实现需要 2 秒来准备一批数据,然后在 GPU 上运行 1 秒训练,然后重新开始准备一批数据。这真的是低效的。
我正在寻找一种方法来异步准备批次并将它们上传到 GPU 以加快训练速度。或者,另一种方法是在调用之间缓存数据集input_fn
(这dataset.cache()
似乎不是一个好的选择,因为必须在每次 input_fn 调用时重新创建数据集)。
这是我的代码的简化版本:
def input_fn(filenames, labels, epochs):
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
if shuffle:
dataset = dataset.shuffle(buffer_size=len(labels))
dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
dataset = dataset.batch(128)
dataset = dataset.repeat(epochs) # to iterate over the training set forever
iterator = dataset.dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
train_input_fn = lambda : input_fn(train_files, train_labels, None)
eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1)
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
我注意到 Estimator API 正在积极开发中,并且在 tensorflow 的主分支中 input_fn 已经可以返回数据集,所以也许我问得太早了,这个功能还没有准备好。但如果是这样,请提供可以跟踪此实施的票证。