我正在使用 tensorflow 联合 API 研究 federated_learning_for_image_classification.ipynb。
在示例中,我可以检查每个模拟客户训练的准确度、损失和总准确度、总损失。
但是没有检查点文件。
我想制作每个客户端检查点文件和总检查点文件。
然后比较客户端参数变量和总参数变量。
任何人都可以帮助我在 federated_learning_for_image_classification.ipynb 示例中制作检查点文件吗?
我正在使用 tensorflow 联合 API 研究 federated_learning_for_image_classification.ipynb。
在示例中,我可以检查每个模拟客户训练的准确度、损失和总准确度、总损失。
但是没有检查点文件。
我想制作每个客户端检查点文件和总检查点文件。
然后比较客户端参数变量和总参数变量。
任何人都可以帮助我在 federated_learning_for_image_classification.ipynb 示例中制作检查点文件吗?
要问的一个问题是您是否要比较TFF内的变量(作为联合计算的一部分)或事后/外部 TFF(在 Python 内分析)。
修改tff.utils.IterativeProcess
执行的构造tff.learning.build_federated_averaging_process
可能是一个好方法。事实上,我建议在 GitHub 上分叉简化实现tensorflow_federated/python/research/simple_fedavg/simple_fedavg.py
,而不是深入研究tff.learning
.
将执行 a更新的行从客户端更改为will 将给出所有客户端模型的列表,然后可以将其与全局模型进行比较。tff.fedetated_mean
tff.federated_collect
例子:
client_deltas = tff.federated_collect(client_outputs.weights_delta)
@tff.tf_computation(server_state.model.type_signature,
client_deltas.type_signature)
def compare_deltas_to_global(global_model, deltas):
for delta in deltas:
# do something with delta vs global_model
tff.federated_apply(compare_deltas_to_global, (server_state.model, client_deltas))