我nn.Module
用 torch.jit.script 转换并以 .pt 格式保存。该模块中的 forward 函数有一个 Int 参数。
def forward(self, x: Tensor, id : int) -> Tensor:
print(id)
x = self._forward(x)
return x
当我在 C++ 中加载模块时,我像这样传递张量,
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
at::Tensor output = module.forward(inputs).toTensor();
但是我应该如何为 Int 编写它?我应该使用哪个结构?