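I am analyzing a method I implemented in TensorFlow Federated with FedAvg. I need to create a histogram of the weight deltas from each client that communicates with the server. Each client calls simulation/federated_avaraging.py separately, but the problem is that I can't call the `tf.summary.histogram()` API there. Any help would be appreciated.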
1 Answer
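In TFF, TensorFlow code represents "local computation", meaning code that runs at a single placement such as one client. A TF op like `tf.summary.histogram` called inside client code therefore only ever sees that one client's values. If you need to inspect something across clients, you first have to aggregate the values you want via TFF, or inspect the returned values in native Python.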
If you want to use TF ops, I would recommend using the `tff.federated_collect` intrinsic to "gather" all the values you want on the server, then `federated_map` a TF function that takes these values and produces your desired visualization.
If you would rather work at the Python level, there is an easy option (this is the approach I would take): simply return the clients' training results from your `tff.federated_computation`. When you invoke this computation, TFF materializes those results as a Python list, and you can visualize it however you want. This would look roughly like the following:
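```python
import tensorflow_federated as tff

@tff.federated_computation(...)
def train_one_round(...):
  ...
  # Local training at every client; the result is placed at CLIENTS.
  trained_clients = run_training(...)
  # Aggregate the client results into a new global model.
  new_model = update_global_model(trained_clients, ...)
  # Returning the client results as well exposes them to the Python caller.
  return new_model, trained_clients
```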
In this example the function returns a tuple whose second element, once the computation is invoked, is a Python list holding the training results from all participating clients.
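Once that Python list is in hand, building per-client delta histograms is ordinary NumPy work. The sketch below assumes, hypothetically, that each client result can be reduced to a list of weight arrays with the same shapes and order as the global model's weights; the helper names `client_weight_deltas` and `delta_histogram` are illustrative, not part of TFF. If your client update already returns a delta, you can skip the subtraction.

```python
import numpy as np


def client_weight_deltas(global_weights, client_weights):
    """Return one flattened (client - global) delta vector per client.

    Assumes (hypothetical layout) that every entry of `client_weights` is a
    list of arrays matching `global_weights` in shape and order.
    """
    deltas = []
    for weights in client_weights:
        layer_deltas = [
            np.asarray(w, dtype=np.float64) - np.asarray(g, dtype=np.float64)
            for w, g in zip(weights, global_weights)
        ]
        # Concatenate all layers into a single flat vector for this client.
        deltas.append(np.concatenate([d.ravel() for d in layer_deltas]))
    return deltas


def delta_histogram(delta, bins=10, value_range=None):
    """Bucket counts and bin edges for one client's flattened delta."""
    return np.histogram(delta, bins=bins, range=value_range)


if __name__ == "__main__":
    # Toy weights standing in for the results returned by train_one_round.
    global_weights = [np.zeros((2, 2)), np.zeros(3)]
    client_weights = [
        [np.ones((2, 2)), np.full(3, 2.0)],
        [np.full((2, 2), -1.0), np.zeros(3)],
    ]
    for i, delta in enumerate(client_weight_deltas(global_weights, client_weights)):
        counts, edges = delta_histogram(delta, bins=4, value_range=(-2.0, 2.0))
        print(f"client {i}: counts={counts.tolist()} edges={edges.tolist()}")
```

Because this runs in the eager Python driver loop rather than inside the federated computation, you could also log each flattened delta to TensorBoard there, e.g. with `writer = tf.summary.create_file_writer(logdir)` and `tf.summary.histogram(f"client_{i}/delta", delta, step=round_num)` inside `with writer.as_default():`, where `logdir` and `round_num` are placeholders for your own values.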