从训练中获取参数:
unchanged_model_dict = train() #unchanged_model_dict is the result after training new_state_dict = OrderedDict() #initialization a dict to store the averaged result从文件中获取参数:
net.load_state_dict(torch.load('./save_parameters_B')) #load the file to be averaged model_dict_B = net.state_dict() #get the parameters from the file开始平均两个部分:
for key in unchanged_model_dict.keys(): new_state_dict[key] = (model_dict_B[key] + unchanged_model_dict[key])/2保存计算结果:
torch.save(new_state_dict,('./save_parameters_B'))
每次我再次加载文件并继续训练时,只能得到相同的文件中参数的损失'./save_parameters_B'不变,而不是new_state_dict. 我不知道为什么。