
I am trying to convert a TorchScript model to ONNX format, using PyTorch 1.8.0. However, the following code raises an error:

import torch

device = torch.device('cpu')

# trained_model is the nn.Module instance defined earlier in the script
if torch.cuda.is_available():
    trained_model.load_state_dict(torch.load('model.pt'))
else:
    # map_location is an argument of torch.load, not of load_state_dict
    trained_model.load_state_dict(torch.load('model.pt', map_location=device))

# Variable is a no-op wrapper since PyTorch 0.4; a plain tensor is enough
dummy_input = torch.randn(1, 1, 1, 1, 1)

trained_model.eval()
torch.onnx.export(trained_model, dummy_input, "model.onnx",
                  input_names=['input'], output_names=['output'])
 

The error is:

/usr/local/lib/python3.6/dist-packages/torch/serialization.py:589: UserWarning: 'torch.load' received a zip file that looks like a TorchScript archive dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to silence this warning)
  " silence this warning)", UserWarning)
Traceback (most recent call last):
  File "script.py", line 131, in <module>
    trained_model.load_state_dict(torch.load('model.pt',map_location=device))
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 591, in load
    return torch.jit.load(opened_file)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_serialization.py", line 164, in load
    cu, f.read(), map_location, _extra_files
RuntimeError: Could not run 'aten::empty_strided' with arguments from the 'CUDA' backend.This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::empty_strided' is only available for these backends: [CPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].
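A possible explanation, based on the first two lines of the traceback: `torch.load` detects the TorchScript archive and forwards it to `torch.jit.load(opened_file)` without passing `map_location` along, so the CUDA tensors stored in the archive cannot be restored on a CPU-only machine. The warning itself suggests calling `torch.jit.load` directly. The sketch below does that and passes `map_location` to it, assuming `model.pt` was written by `torch.jit.save`. `TinyNet` and the temporary file path are hypothetical stand-ins so the snippet runs on its own; it has not been checked against the asker's actual model.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for the asker's network
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Create a TorchScript archive to load (stands in for 'model.pt')
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.jit.save(torch.jit.script(TinyNet()), path)

# Call torch.jit.load directly so that map_location is honored;
# tensors that were saved on CUDA would be remapped to the CPU here
device = torch.device("cpu")
scripted = torch.jit.load(path, map_location=device)

# The result is a ScriptModule, not a state dict
print(type(scripted).__name__)
```

Note that the object returned is a whole module, so it still cannot be handed to `load_state_dict` as-is.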

When I replace torch.load with torch.jit.load, I get the following error instead:

  File "script.py", line 131, in <module>
    trained_model.load_state_dict(torch.jit.load('model.pt',map_location=device))
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1196, in load_state_dict
    state_dict = state_dict.copy()
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_script.py", line 561, in __getattr__
    return super(RecursiveScriptModule, self).__getattr__(attr)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_script.py", line 291, in __getattr__
    return super(ScriptModule, self).__getattr__(attr)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 948, in __getattr__
    type(self).__name__, name))
AttributeError: 'RecursiveScriptModule' object has no attribute 'copy'
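This second error appears to come from `load_state_dict` calling `.copy()` on its argument (see `state_dict = state_dict.copy()` in the traceback): `torch.jit.load` returns a `RecursiveScriptModule`, not a dictionary. One possible workaround is to take `state_dict()` from the loaded ScriptModule and pass that to the eager model. This is a sketch that assumes `trained_model` has exactly the same architecture, and therefore the same parameter names, as the network that was scripted; `TinyNet` and the temporary file are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for the architecture behind 'model.pt'
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Simulate the saved TorchScript archive
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.jit.save(torch.jit.script(TinyNet()), path)

# Load the archive as a module, remapping tensors to the CPU
scripted = torch.jit.load(path, map_location=torch.device("cpu"))

# Copy the scripted module's weights into a freshly built eager model;
# state_dict() returns an ordinary dict keyed like 'fc.weight', 'fc.bias'
trained_model = TinyNet()
trained_model.load_state_dict(scripted.state_dict())
trained_model.eval()
```

After this, `trained_model` is an ordinary `nn.Module` holding the saved weights, so the `torch.onnx.export` call from the question can be applied to it as with any eager model.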

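Basically, how do I load the weights saved in the model so that I can export it to ONNX? The PyTorch documentation covers regular PyTorch models, but when I follow the same steps for a TorchScript model, it fails.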
