  1. Get the parameters from training:

     from collections import OrderedDict
     import torch

     unchanged_model_dict = train()  # unchanged_model_dict is the result after training
     new_state_dict = OrderedDict()  # initialize a dict to store the averaged result
    
  2. Get the parameters from the file:

     net.load_state_dict(torch.load('./save_parameters_B'))  # load the file to be averaged
     model_dict_B = net.state_dict()  # get the parameters from the file
    
  3. Average the two sets of parameters:

     for key in unchanged_model_dict.keys():
         new_state_dict[key] = (model_dict_B[key] + unchanged_model_dict[key]) / 2  # element-wise average
    
  4. Save the result:

     torch.save(new_state_dict, './save_parameters_B')  # overwrite file B with the averaged parameters
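
The four steps can be combined into a minimal, runnable sketch. It assumes a small hypothetical `nn.Linear(4, 2)` in place of `net`, a hypothetical `train()` that simply shifts the weights and returns `net.state_dict()`, and a temporary file standing in for `./save_parameters_B`. One change is deliberate: the trained parameters are deep-copied before file B is loaded into the same `net`, because `state_dict()` returns tensors that share storage with the module's parameters and `load_state_dict()` copies into them in place.

```python
import copy
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the real network.
net = nn.Linear(4, 2)


def train():
    # Hypothetical "training": shift every parameter so it differs from file B.
    with torch.no_grad():
        for p in net.parameters():
            p.add_(1.0)
    return net.state_dict()


# Hypothetical file B holding a second, independent set of parameters.
path_b = os.path.join(tempfile.mkdtemp(), "save_parameters_B")
torch.save(nn.Linear(4, 2).state_dict(), path_b)

# 1. Parameters from training, deep-copied so step 2 cannot overwrite them.
unchanged_model_dict = copy.deepcopy(train())
new_state_dict = OrderedDict()

# 2. Parameters from file B, loaded into the same net.
net.load_state_dict(torch.load(path_b))
model_dict_B = net.state_dict()

# 3. Element-wise average of every tensor.
for key in unchanged_model_dict.keys():
    new_state_dict[key] = (model_dict_B[key] + unchanged_model_dict[key]) / 2

# 4. Save the averaged parameters, overwriting file B.
torch.save(new_state_dict, path_b)
```

Without the `copy.deepcopy`, and assuming `train()` returns the state dict of the same `net`, `unchanged_model_dict` would end up holding file B's values after step 2, and averaging B with itself would just give B again.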
    

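Every time I reload the file and continue training, I get the same loss as with the original parameters in './save_parameters_B', as if the file were unchanged, instead of the loss for new_state_dict. I don't know why.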
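
One thing worth checking, assuming `train()` returns `net.state_dict()` for the same `net` that file B is later loaded into: `state_dict()` returns detached tensors that share storage with the module's parameters, so a later `load_state_dict()` on that module overwrites them in place. The sketch below demonstrates this with a hypothetical `nn.Linear(4, 2)` and a second hypothetical module standing in for file B.

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)

# state_dict() tensors share storage with the module's parameters.
sd = net.state_dict()
print(sd["weight"].data_ptr() == net.weight.data_ptr())  # True

# An independent snapshot taken before loading anything else.
snapshot = {k: v.clone() for k, v in sd.items()}

# Load different values into the same module (stand-in for file B).
other = nn.Linear(4, 2)
with torch.no_grad():
    other.weight.fill_(3.0)
net.load_state_dict(other.state_dict())

# The earlier state_dict now holds the loaded values; the clone does not.
print(torch.equal(sd["weight"], other.weight))        # True
print(torch.equal(snapshot["weight"], other.weight))  # False
```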

