3

我需要在 C++ 中运行一个预训练的 pytorch nn 模型(在 python 中训练)来进行预测。

为此,我按照此处给出的关于如何在 c++ 中加载 pytorch 模型的说明进行操作:https ://pytorch.org/tutorials/advanced/cpp_export.html

但是,当我尝试按照教程第一步中所述通过跟踪获取 torch.jit.ScriptModule 时:

    traced_script_module =
        torch.jit.trace(model, (input_tensor_1, input_tensor_2))

它不是返回一个 torch.jit.ScriptModule,而是返回一个函数:

    print(type(traced_script_module))
    <type 'function'>

其中,当我运行时:

    traced_script_module.save("model.pt")

然后导致以下错误:

Traceback (most recent call last):
  File "serialize_model.py", line 60, in <module>
    traced_script_module.save("model.pt")
AttributeError: 'function' object has no attribute 'save'

关于我做错了什么的任何想法?

4

1 回答 1

2

感谢您询问Jatentaki。我在 Python 中使用 PyTorch 0.4,当我更新到 1.0 时它工作了。

于 2019-02-13T05:17:23.450 回答