我正在尝试使用 pytorch/libtorch 实现一个简单的神经网络。以下示例改编自libtorch cpp 前端教程。
#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
DeepQImpl(size_t N)
: linear1(2,5),
linear2(5,3) {}
torch::Tensor forward(torch::Tensor x) const {
x = torch::tanh(linear1(x));
x = linear2(x);
return x;
}
torch::nn::Linear linear1, linear2;
};
TORCH_MODULE(DeepQ);
请注意,该函数forward
已声明const
。我正在编写的代码要求 NN 的评估是一个 const 函数,这对我来说似乎是合理的。但是,此代码无法编译。编译器抛出
错误:不匹配调用 '(const torch::nn::Linear) (at::Tensor&)'<br> x = linear1(x);
我已经找到了解决这个问题的方法,通过将图层定义为mutable
:
#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
/* all the code */
mutable torch::nn:Linear linear1, linear2;
};
所以我的问题是
- 为什么在张量上应用层不是
const
- 正在使用
mutable
这种方法来解决这个问题,它安全吗?
我的直觉是,在前向传播中,层被组装成一个可用于反向传播的结构,需要一些写入操作。如果这是真的,那么问题就变成了如何在第一步(非const
)中组装层,然后在第二步(const
)中评估结构。