在来自 tensorflow API 的ParameterServerTrainingmodel.fit
教程代码中,在部分中有以下代码片段
def dataset_fn(input_context):
global_batch_size = 64
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
x = tf.random.uniform((10, 10))
y = tf.random.uniform((10,))
dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()
dataset = dataset.shard(
input_context.num_input_pipelines,
input_context.input_pipeline_id)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(2)
return dataset
dc = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
也有人说
The code in dataset_fn will be invoked on the input device, which is usually the CPU, on each of the worker machines.
这是否意味着数据集必须在每个工作服务器的同一存储上(比如参数服务器和工作服务器是不同的机器)?
或者,一台机器上的参数服务器有什么方法可以将训练数据发送给工作人员,而工作人员机器没有将数据集直接存储在我不明白的 ParameterServerStrategy 中?