我创建了一个这样的 CNN 模型global_model = CNNMnist(args=args)
。然后我将它发送到设备,将其设置为训练。然后我训练我的本地模型,收集 local_weights 和它们的平均值以获得更新的 global_model。现在我正在尝试从.parameters()
函数中获取项目,但我得到的None
只是item.grad
. 当我为 local_models 做同样的事情时,我得到了想要的输出。我究竟做错了什么?
global_model.to(device)
global_model.train()
...................
global_weights = average_weights(local_weights)
global_model.load_state_dict(global_weights)
last_update = []
for item in global_model.parameters():
last_update.append(copy.deepcopy(item.grad))
print(item.grad)
Output: None None None None None None None None
任何帮助,将不胜感激。