我们正在 Tensorflow 上运行多 GPU 作业,并评估从基于队列的模型(使用 string_input_producer 接口)到新的 Tensorflow 数据集 API 的迁移。后者似乎提供了一种同时在训练和验证之间切换的更简单方法。
下面的一段代码显示了我们是如何做到这一点的。
train_dataset, train_iterator = get_dataset(train_files, batch_size, epochs)
val_dataset, val_iterator = get_dataset(val_files, batch_size, epochs)
is_validating = tf.placeholder(dtype=bool, shape=())
next_batch = tf.cond(is_validating,
lambda: val_iterator.get_next(),
lambda: train_iterator.get_next())
validation_tower = self.num_gpus - 1
tower_grads = []
for i in range(self.num_gpus):
with tf.variable_scope(tf.get_variable_scope(),reuse=(i > 0)):
with tf.device('/gpu:%d' % i), tf.name_scope('%s_%d' % ('gpu_', i)) as scope:
if i == validation_tower:
images, labels = next_batch
# Loss funcs snipped out
else:
images, labels = next_batch
# Loss funcs snipped out
get_dataset 函数构建数据集,设置映射函数和批量大小。它还构建了一个迭代器,但不初始化它。迭代器的初始化发生在会话开始之前。
is_validating 布尔值在会话运行时提供,每隔几步我们通过 feed_dict 将 is_validating 传递为 True 以使用验证数据集
我的问题是:
假设我有 8 个 GPU,所以我们在 7 个 GPU 上运行训练。对于这 7 个 GPU 中的每一个,迭代器是否从同一点前进,从而为所有 7 个 GPU 提供相同的数据?