1

我在 中训练了一个QAT基于(量化感知训练)的模型Pytorch,训练进行得很顺利。然而,当我尝试将权重加载到融合模型中并在更宽的数据集上运行测试时,我遇到了很多错误:

(base) marian@u04-2:/mnt/s3user/Pytorch_Retinaface_quantized# python test_widerface.py --trained_model ./weights/mobilenet0.25_Final_quantized.pth --network mobile0.25layers:  
Loading pretrained model from ./weights/mobilenet0.25_Final_quantized.pth
remove prefix 'module.'
Missing keys:235
Unused checkpoint keys:171
Used keys:65
Traceback (most recent call last):
  File "/root/.vscode/extensions/ms-python.python-2020.1.58038/pythonFiles/ptvsd_launcher.py", line 43, in <module>
    main(ptvsdArgs)
  File "/root/.vscode/extensions/ms-python.python-2020.1.58038/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 432, in main
    run()
  File "/root/.vscode/extensions/ms-python.python-2020.1.58038/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 316, in run_file
    runpy.run_path(target, run_name='__main__')
  File "/root/anaconda3/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/root/anaconda3/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/root/anaconda3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/mnt/f3user/Pytorch_Retinaface_quantized/test_widerface.py", line 114, in <module>
    net = load_model(net, args.trained_model, args.cpu)
  File "/mnt/f3user/Pytorch_Retinaface_quantized/test_widerface.py", line 95, in load_model
    model.load_state_dict(pretrained_dict, strict=False)
  File "/root/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for RetinaFace:
        While copying the parameter named "ssh1.conv3X3.0.weight", whose dimensions in the model are torch.Size([32, 64, 3, 3]) and whose dimensions in the checkpoint are torch.Size([32, 64, 3, 3]).
        While copying the parameter named "ssh1.conv5X5_2.0.weight", whose dimensions in the model are torch.Size([16, 16, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 16, 3, 3]).
        While copying the parameter named "ssh1.conv7x7_3.0.weight", whose dimensions in the model are torch.Size([16, 16, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 16, 3, 3]).
        While copying the parameter named "ssh2.conv3X3.0.weight", whose dimensions in the model are torch.Size([32, 64, 3, 3]) and whose dimensions in the checkpoint are torch.Size([32, 64, 3, 3]).
        While copying the parameter named "ssh2.conv5X5_2.0.weight", whose dimensions in the model are torch.Size([16, 16, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 16, 3, 3]).
.....

完整列表可以在这里找到。
基本上找不到权重。加上融合模型中缺少的比例和零点。

如果重要,以下代码段是用于训练和保存模型的实际训练循环:

if __name__ == '__main__':
    # train()
    ...
    net = RetinaFace(cfg=cfg)
    print("Printing net...")
    print(net)

    net.fuse_model()
    ...

    net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(net, inplace=True)
    print(f'quantization preparation done.')

    ... 

    quantized_model = net 
    for i in range(max_epoch):
        net = net.to(device)
        train_one_epoch(net, data_loader, optimizer, criterion, cfg, gamma, i, step_index, device)
        if i in stepvalues:
            step_index += 1
        if i > 3 :
            net.apply(torch.quantization.disable_observer)
        if i > 2 :
            net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
        net=net.cpu()
        quantized_model = torch.quantization.convert(net.eval(), inplace=False)
        quantized_model.eval()
        # evaluate on test set ?!

    torch.save(net.state_dict(), save_folder + cfg['name'] + '_Final.pth')
    torch.save(quantized_model.state_dict(), save_folder + cfg['name'] + '_Final_quantized.pth')
    #torch.jit.save(torch.jit.script(quantized_model), save_folder + cfg['name'] + '_Final_quantized_jit.pth')


用于测试test_widerface.py使用的可以在此处访问您可以在此处
查看密钥

为什么会这样?这应该如何处理?

更新

我检查了名称,并创建了一个新的 state_dict 字典,并使用下面的代码片段插入了检查点和模型中的 112 个键:

new_state_dict  = {}
checkpoint_state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) 
for (ck, cp) in checkpoint_state_dict.items():
    for (mk, mp) in model.state_dict().items():
        kname,kext = os.path.splitext(ck)
        mname,mext = os.path.splitext(mk)
        # check the two parameter and see if they are the same
        # then use models key naming scheme and use checkpoints weights
        if kname+kext == mname+mext or kname+'.0'+kext == mname+mext:
            new_state_dict[mname+mext] = cp 
        else: 
             if kext in ('.scale','.zero_point'):
                 new_state_dict[ck] = cp

然后使用这个新的 state_dict!但是我得到了完全相同的错误!意思是这样的错误:

RuntimeError: Error(s) in loading state_dict for RetinaFace:
        While copying the parameter named "ssh1.conv3X3.0.weight", whose dimensions in the model are torch.Size([32, 64, 3, 3]) and whose dimensions in the checkpoint are torch.Size([32, 64, 3, 3]).

这真的很令人沮丧,并且没有关于此的文档!我在这里完全一无所知。

4

1 回答 1

0

我终于找到了原因。格式为 的错误消息:

复制名为“xxx.weight”的参数时,模型中的尺寸为torch.Size([yyy]),检查点中的尺寸为torch.Size([yyy])。

实际上是通用消息,仅在复制相关参数时发生异常时返回。

Pytorch 开发人员可以轻松地将实际的异常参数添加到这个虚假但无用的消息中,因此它实际上可以帮助更好地调试手头的问题。无论如何,看看顺便说一句的异常:

"copy_" not implemented for \'QInt8' 

您现在将知道实际问题是什么!

于 2020-01-28T08:44:00.327 回答