我需要在 C++ 中运行一个预训练的 pytorch nn 模型(在 python 中训练)来进行预测。
为此,我按照此处给出的关于如何在 c++ 中加载 pytorch 模型的说明进行操作:https ://pytorch.org/tutorials/advanced/cpp_export.html
但是,当我尝试按照教程第一步中所述通过跟踪获取 torch.jit.ScriptModule 时:
traced_script_module =
torch.jit.trace(model, (input_tensor_1, input_tensor_2))
它不是返回一个 torch.jit.ScriptModule,而是返回一个函数:
print(type(traced_script_module))
<type 'function'>
其中,当我运行时:
traced_script_module.save("model.pt")
然后导致以下错误:
Traceback (most recent call last):
File "serialize_model.py", line 60, in <module>
traced_script_module.save("model.pt")
AttributeError: 'function' object has no attribute 'save'
关于我做错了什么的任何想法?