0

在来自 tensorflow API 的ParameterServerTrainingmodel.fit教程代码中,在部分中有以下代码片段

def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)

  x = tf.random.uniform((10, 10))
  y = tf.random.uniform((10,))

  dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines,
      input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)

  return dataset

dc = tf.keras.utils.experimental.DatasetCreator(dataset_fn)

也有人说

The code in dataset_fn will be invoked on the input device, which is usually the CPU, on each of the worker machines.

这是否意味着数据集必须在每个工作服务器的同一存储上(比如参数服务器和工作服务器是不同的机器)?

或者,一台机器上的参数服务器有什么方法可以将训练数据发送给工作人员,而工作人员机器没有将数据集直接存储在我不明白的 ParameterServerStrategy 中?

4

1 回答 1

0

为了社区的利益,在这里回答。

从评论部分:

(如果有人有同样的疑问)经过进一步研究,我发现,我们可以在参数服务器存在的 1 个服务器上启动一个协调器,我们可以使用 ps 启动工作人员和参数 tf.distribute.Server(),这需要来自协调器的减少调用或训练调用。检查此链接 tensorflow.org/api_docs/python/tf/distribute/Server

于 2021-08-27T16:08:58.467 回答