1

我有一个使用 跟踪和保存的脚本torch.jit.save,当我使用以下命令在 python 中加载它时:

  net = torch.jit.load(r"script.pt")
  net.to('cuda')
  net.train()
  print(net.training)
  net.eval()
  print(net.training)

我得到了预期的输出:

True
False

但是,当我使用 libtorch 和以下代码加载相同的文件时:

  auto module = torch::jit::load("script.pt");
  module.to(torch::kCUDA);
  module.train();
  std::cout << std::boolalpha << module.is_training() << std::endl;
  module.eval();
  std::cout << std::boolalpha << module.is_training() << std::endl;

输出是

true
true

进入该is_training()函数,它会查找该training属性,但它不存在,因此默认为 true

  /// True if the module is in training mode.
  bool is_training() const {
    return attr("training", true).toBool();
  }

使用此模块有效,我可以调用它forward并处理数据,但我需要能够将训练模式设置为false

有什么想法吗?追踪时有什么需要做的吗?

我在带有 CUDA 10.2 的 Ubuntu 16.04 上使用 PyTorch 1.5.0

4

0 回答 0