我正在联邦学习设置中训练一个 CNN 模型。我更新我的local_models,然后我平均weights我从local_model更新中得到weight的global_model。由于我在训练loss.backward()期间使用local_model,我可以使用以下方法获得渐变:
updates_list = []
for item in model.parameters():
updates_list.append(copy.deepcopy(item.grad))
print(updates_list)
但我不能.grad用于我的,因为在权重平均期间global_model没有函数使用。我的问题是,我如何获得我的更新渐变?backward()local_modelglobal_model