
I am trying to load a TorchScript model in C++, but I get the error

RuntimeError: expected scalar type Double but found Float

The full output is:

$ ./example-app ../turb_nn.pt
1 2 3 4 5
 1  2  3  4  5
[ CPUDoubleType{1,5} ]
ok
run forward
terminate called after throwing an instance of 'std::runtime_error'
  what():  The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in forward

      def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: number = 1.0, alpha: number = 1.0):
          return self + mat1.mm(mat2)
                        ~~~~~~~ <--- HERE

      def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
RuntimeError: expected scalar type Double but found Float

Aborted (core dumped)

The input data is actually of type double. I tried using .double() here, but I still could not get it to run. My code:

#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  std::vector<double> inputdata({ 1.0, 2.0, 3.0, 4.0, 5.0 });
  std::cout << inputdata << std::endl;

  auto opts = torch::TensorOptions().dtype(torch::kDouble);
  torch::Tensor input = torch::from_blob(inputdata.data(), {1, (int)inputdata.size()}, opts).to(torch::kDouble);
  std::cout << input << std::endl;

  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::cout << "ok" << std::endl;

    
  std::vector<torch::jit::IValue> input_invar;
  input_invar.push_back(input);
  std::cout << "run forward" << std::endl;
  at::Tensor output = module.forward(input_invar).toTensor();

  std::vector<double> out_vect(output.data_ptr<double>(), output.data_ptr<double>() + output.numel());

  std::cout << out_vect << std::endl;
}

Could you give me some suggestions?
