我有一个使用 跟踪和保存的脚本torch.jit.save
,当我使用以下命令在 python 中加载它时:
net = torch.jit.load(r"script.pt")
net.to('cuda')
net.train()
print(net.training)
net.eval()
print(net.training)
我得到了预期的输出:
True
False
但是,当我使用 libtorch 和以下代码加载相同的文件时:
auto module = torch::jit::load("script.pt");
module.to(torch::kCUDA);
module.train();
std::cout << std::boolalpha << module.is_training() << std::endl;
module.eval();
std::cout << std::boolalpha << module.is_training() << std::endl;
输出是
true
true
进入该is_training()
函数,它会查找该training
属性,但它不存在,因此默认为 true
/// True if the module is in training mode.
bool is_training() const {
return attr("training", true).toBool();
}
使用此模块有效,我可以调用它forward
并处理数据,但我需要能够将训练模式设置为false
有什么想法吗?追踪时有什么需要做的吗?
我在带有 CUDA 10.2 的 Ubuntu 16.04 上使用 PyTorch 1.5.0