0

我正在使用 tensorflow 联合 API 研究 federated_learning_for_image_classification.ipynb。

在示例中,我可以检查每个模拟客户训练的准确度、损失和总准确度、总损失。

但是没有检查点文件。

我想制作每个客户端检查点文件和总检查点文件。

然后比较客户端参数变量和总参数变量。

任何人都可以帮助我在 federated_learning_for_image_classification.ipynb 示例中制作检查点文件吗?

4

1 回答 1

1

要问的一个问题是您是否要比较TFF的变量(作为联合计算的一部分)或事后/外部 TFF(在 Python 内分析)。

修改tff.utils.IterativeProcess执行的构造tff.learning.build_federated_averaging_process可能是一个好方法。事实上,我建议在 GitHub 上分叉简化实现tensorflow_federated/python/research/simple_fedavg/simple_fedavg.py,而不是深入研究tff.learning.

将执行 a更新的行从客户端更改为will 将给出所有客户端模型的列表,然后可以将其与全局模型进行比较。tff.fedetated_meantff.federated_collect

例子:

client_deltas = tff.federated_collect(client_outputs.weights_delta)

@tff.tf_computation(server_state.model.type_signature,
                    client_deltas.type_signature)
def compare_deltas_to_global(global_model, deltas):
  for delta in deltas:
    # do something with delta vs global_model 

tff.federated_apply(compare_deltas_to_global, (server_state.model, client_deltas))
于 2019-10-22T12:48:14.187 回答