我想在 tensorflow 联合教程https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification中打印客户端的本地输出。我应该怎么办?
问问题
272 次
2 回答
0
如果您只想要进入聚合的值列表(例如 into tff.federated_mean
),一种选择是添加额外的输出aggregate_mnist_metrics_across_clients()
以包含使用 计算的指标tff.federated_collect()
。
这可能看起来像:
@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
return {
'num_examples': tff.federated_sum(metrics.num_examples),
'loss': tff.federated_mean(metrics.loss, metrics.num_examples),
'accuracy': tff.federated_mean(metrics.accuracy, metrics.num_examples),
'per_client/num_examples': tff.federated_collect(metrics.num_examples),
'per_client/loss': tff.federated_collect(metrics.loss),
'per_client/accuracy': tff.federated_collect(metrics.accuracy),
}
当计算运行时,它将在稍后打印几个单元格:
state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))
round 1, metrics=<...,per_client/accuracy=[0.14516129, 0.10642202, 0.13972603],per_client/loss=[3.2409852, 3.417463, 2.9516447],per_client/num_examples=[930.0, 1090.0, 730.0]>
但是请注意:如果您想知道特定客户的价值,则故意无法做到这一点。通过设计,TFF 的语言有意避免客户身份的概念;希望避免使客户可寻址。
于 2019-05-28T03:41:58.750 回答
0
如果你想在 'client_update' 函数中打印一些东西,你可以使用tf.print()
.
于 2022-02-15T12:30:32.927 回答