I'm exporting a PyTorch model via TorchScript tracing, but I've run into a problem. Specifically, I need to perform some operations on tensor sizes, but the JIT compiler hardcodes the variable shapes as constants, which breaks compatibility with tensors of different sizes.
For example, consider this class:
import torch
import torch.nn as nn

class Foo(nn.Module):
    """Toy class that plays with tensor shape to showcase tracing issue.

    It creates a new tensor with the same shape as the input one, except
    for the last dimension, which is doubled. This new tensor is filled
    based on the values of the input.
    """
    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, x):
        new_shape = (x.shape[0], 2 * x.shape[1])  # incriminated instruction
        x2 = torch.empty(size=new_shape)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2
and run this test code:
x = torch.randn((3, 5)) # create example input
foo = Foo()
traced_foo = torch.jit.trace(foo, x) # trace
print(traced_foo(x).shape) # obviously this works
print(traced_foo(x[:, :4]).shape) # but fails with a different shape!
I could work around this by using scripting instead, but in this case I really need to use tracing. Moreover, I would expect tracing to be able to handle operations on tensor sizes correctly.
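As a sketch of one possible trace-friendly workaround (an assumption on my part, not something from the original setup): if the interleaving is expressed purely with tensor ops such as torch.stack and torch.flatten, no shape value ever becomes a Python int, so the traced graph keeps the sizes symbolic and generalizes to other input widths:

```python
import torch
import torch.nn as nn

class Foo(nn.Module):
    """Trace-friendly variant: builds the interleaved output with
    stack + flatten instead of allocating a tensor from Python-int shapes."""
    def forward(self, x):
        # (N, M) -> (N, M, 2) -> (N, 2*M); pairs (x[i,j], x[i,j]+1)
        # end up adjacent, matching x2[:, ::2] = x; x2[:, 1::2] = x + 1
        return torch.stack((x, x + 1), dim=-1).flatten(start_dim=1)

x = torch.randn((3, 5))
traced = torch.jit.trace(Foo(), x)
print(traced(x).shape)         # torch.Size([3, 10])
print(traced(x[:, :4]).shape)  # torch.Size([3, 8]) -- no longer fails
```

The key difference is that stack and flatten are recorded as symbolic graph ops, whereas indexing x.shape returns plain Python integers that the tracer can only bake in as constants.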