
I'd like to speed up a training routine that uses the Estimator API with an input_fn written with tf.data.Dataset.

My implementation takes 2 seconds to prepare a batch of data, then runs training on the GPU for 1 second, and then starts over preparing the next batch. This is really inefficient.

I'm looking for a way to prepare the batches asynchronously and upload them to the GPU to speed up the training. Alternatively, another option would be to cache the dataset across calls to input_fn (dataset.cache() doesn't seem to be a good choice here, because the dataset has to be recreated on every input_fn call).

Here is a simplified version of my code:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
  if shuffle:
     dataset = dataset.shuffle(buffer_size=len(labels))
  dataset = dataset.map(_post_process,  num_parallel_calls=num_map_threads)
  dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
  dataset = dataset.batch(128)
  dataset = dataset.repeat(epochs) # to iterate over the training set forever
  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels

train_input_fn = lambda : input_fn(train_files, train_labels, None)
eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) 
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

I've noticed that the Estimator API is under active development and that in tensorflow's master branch input_fn can already return a dataset, so maybe I'm asking too early and this feature isn't ready yet. If that's the case, please point me to a ticket where this work can be tracked.
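For reference, if input_fn is allowed to return the dataset directly (a sketch of what I have in mind, assuming that feature is available in the build being used), the end of the function would simply be:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  # ... same map / shuffle / batch / repeat pipeline as above ...
  return dataset  # no iterator: the Estimator would create it itself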


2 Answers


Using tf.data.Dataset.cache() is indeed not a good choice since it will cache the whole dataset into memory, which takes time and might overflow your memory.
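For contrast, here is roughly where cache() would sit in a pipeline like yours (a sketch, not a recommendation): the no-argument form keeps every element seen so far in memory, while the file-backed form cache(filename) trades memory for disk I/O:

dataset = ...
dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
dataset = dataset.cache()                    # in memory: fast, but holds the whole dataset
# dataset = dataset.cache('/tmp/wav_cache')  # on disk: avoids the memory blow-up
dataset = dataset.batch(128)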

The way to go is to use tf.data.Dataset.prefetch() at the end of your pipeline, which will always make sure that the data pipeline holds buffer_size elements. It is usually enough to have buffer_size = 1 at the end:

dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)  # prefetch one batch

As explained by @mrry in this answer, you can also try to increase the number of prefetched batches a bit.

Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary.
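A sketch of what that could look like, with the buffer sizes here being arbitrary examples rather than tuned values:

dataset = ...
dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
dataset = dataset.prefetch(4)   # keep a few decoded elements ready ahead of the next stage
dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)   # and one fully assembled batch at the very end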


If you still have a slow input pipeline compared to your GPU computations, you need to increase the number of threads working in parallel using the num_parallel_calls argument of tf.data.Dataset.map().
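In your case that would look something like this (the thread count is an example value, to be tuned against the number of CPU cores available):

num_map_threads = 8  # example value
dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)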

Answered 2018-01-04T17:37:56.020

A few points to add to Olivier's answer, mostly from this post:

  • repeat before shuffle is slightly faster, at the downside of blurred epoch boundaries. This may matter in rare cases, but I doubt it.
  • shuffle before mapping - this reduces the memory footprint of the shuffle buffer, since it only needs to buffer the filenames rather than the file contents.
  • To me it makes more sense to apply the third map transformation to the output of get_next() rather than to the dataset - not sure if this affects speed. You could also consider putting the other two map calls into a single one to reduce scheduling issues.
  • Experiment with repeat-ing before batch-ing. It probably won't make a difference, but it might be minor. If you repeat before shuffle as mentioned above, you will have to.
  • As Olivier mentioned, use prefetch.

Modified code:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.repeat(epochs)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(labels))

  def combined_map_fn(*args):
    return _post_process(_read_wav(*args))

  dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
  dataset = dataset.batch(128)
  dataset = dataset.prefetch(1)

  iterator = dataset.make_one_shot_iterator()
  wavs, labels = iterator.get_next()
  features = {'wav': wavs}
  return features, labels
Answered 2018-03-17T05:54:31.383