0

我正在联邦学习设置中训练一个 CNN 模型。我更新我的local_models,然后我平均weights我从local_model更新中得到weightglobal_model。由于我在训练loss.backward()期间使用local_model,我可以使用以下方法获得渐变:

updates_list = []
for item in model.parameters():
    updates_list.append(copy.deepcopy(item.grad))
print(updates_list)

但我不能.grad用于我的,因为在权重平均期间global_model没有函数使用。我的问题是,我如何获得我的更新渐变?backward()local_modelglobal_model

4

0 回答 0