1

nn.Module用 torch.jit.script 转换并以 .pt 格式保存。该模块中的 forward 函数有一个 Int 参数。

    def forward(self, x: Tensor, id : int) -> Tensor:
        print(id)
        x = self._forward(x)
        return x

当我在 C++ 中加载模块时,我像这样传递张量,

std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
at::Tensor output = module.forward(inputs).toTensor();

但是我应该如何为 Int 编写它?我应该使用哪个结构?

4

1 回答 1

0

我的前锋:

def forward(self, t0, t1, i):
    ...

如果您尝试输入 an intforward()则在将模型导出到 torchscript 时会得到以下信息jit.trace

Type 'Tuple[Tensor, Tensor, int]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced

因此,您需要将int类型转换tensor为例如

int a = 0;
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 640, 256}));
inputs.push_back(torch::ones({1, 3, 640, 256}));
inputs.push_back(torch::ones({1, a}));
at::Tensor output = module.forward(inputs).toTensor();
于 2022-01-21T12:51:27.137 回答