0

我创建了一个这样的 CNN 模型global_model = CNNMnist(args=args)。然后我将它发送到设备,将其设置为训练。然后我训练我的本地模型,收集 local_weights 和它们的平均值以获得更新的 global_model。现在我正在尝试从.parameters()函数中获取项目,但我得到的None只是item.grad. 当我为 local_models 做同样的事情时,我得到了想要的输出。我究竟做错了什么?

global_model.to(device)
global_model.train()
...................
global_weights = average_weights(local_weights)
global_model.load_state_dict(global_weights)
last_update = []
for item in global_model.parameters():
    last_update.append(copy.deepcopy(item.grad))
    print(item.grad)

Output: None None None None None None None None

任何帮助,将不胜感激。

4

1 回答 1

0

您正在查看从 a 加载的值state_dict- 渐变未保存在那里。.grad在您致电之后backward()和之前尝试打印zero_grad()

于 2021-05-25T05:10:15.613 回答