
I am exporting a PyTorch model via TorchScript tracing, but I'm running into a problem. Specifically, I have to perform some operations on tensor sizes, but the JIT compiler hardcodes the variable shapes as constants, breaking compatibility with tensors of different sizes.

For example, take this class:

import torch
import torch.nn as nn


class Foo(nn.Module):
    """Toy class that plays with tensor shape to showcase tracing issue.

    It creates a new tensor with the same shape as the input one, except
    for the last dimension, which is doubled. This new tensor is filled
    based on the values of the input.
    """
    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, x):
        new_shape = (x.shape[0], 2*x.shape[1])  # incriminated instruction
        x2 = torch.empty(size=new_shape)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2

and run this test code:

x = torch.randn((3, 5))  # create example input

foo = Foo()
traced_foo = torch.jit.trace(foo, x)  # trace
print(traced_foo(x).shape)  # obviously this works
print(traced_foo(x[:, :4]).shape)  # but fails with a different shape!

I could solve this by scripting, but in this case I really need to use tracing. Moreover, I would expect tracing to handle operations on tensor sizes correctly.


1 Answer


> but in this case I really need to use tracing

You can freely mix torch.jit.script and torch.jit.trace wherever you need to. For example, one could do this:

import torch


class MySuperModel(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.scripted = torch.jit.script(Foo(*args, **kwargs))
        self.traced = Bar(*args, **kwargs)

    def forward(self, data):
        return self.scripted(self.traced(data))

# Foo and Bar stand in for your own modules; example_input is a
# sample tensor of the appropriate shape
model = MySuperModel()
torch.jit.trace(model, example_input)

You could also move the shape-dependent part into a separate function and decorate it with @torch.jit.script:

import torch
import torch.nn as nn

@torch.jit.script
def _forward_impl(x):
    new_shape = (x.shape[0], 2*x.shape[1])  # incriminated instruction
    x2 = torch.empty(size=new_shape)
    x2[:, ::2] = x
    x2[:, 1::2] = x + 1
    return x2

class Foo(nn.Module):
    def forward(self, x):
        return _forward_impl(x)
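Putting the pieces together, a minimal self-contained sketch of this hybrid approach (tracing the module while the shape arithmetic lives in a scripted helper, as in the question's Foo) would be:

```python
import torch
import torch.nn as nn

@torch.jit.script
def _forward_impl(x):
    # shape arithmetic runs inside script, so it stays symbolic
    new_shape = (x.shape[0], 2 * x.shape[1])
    x2 = torch.empty(size=new_shape)
    x2[:, ::2] = x
    x2[:, 1::2] = x + 1
    return x2

class Foo(nn.Module):
    def forward(self, x):
        return _forward_impl(x)

# trace with one shape, then call with another: the scripted helper
# is kept as a call in the traced graph, so both inputs work
traced_foo = torch.jit.trace(Foo(), torch.randn(3, 5))
print(traced_foo(torch.randn(3, 5)).shape)  # torch.Size([3, 10])
print(traced_foo(torch.randn(3, 4)).shape)  # torch.Size([3, 8])
```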

There is no way around script for this, because it has to understand your code. Tracing merely records the operations you perform on tensors and has no notion of control flow that depends on the data (or its shape).
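To illustrate that point, here is a small sketch (the function `branchy` is a made-up example, not from the question) showing how tracing bakes in the branch taken for the example input, while scripting keeps the data-dependent condition alive:

```python
import torch

def branchy(x):
    # which branch runs depends on the data, not just the shape
    if bool(x.sum() > 0):
        return x
    return -x

# tracing with a positive input records only the "return x" branch
traced = torch.jit.trace(branchy, torch.ones(3))
# scripting compiles both branches and re-evaluates the condition
scripted = torch.jit.script(branchy)

neg = -torch.ones(3)
print(traced(neg))    # stays negative: the recorded branch ignores the sign
print(scripted(neg))  # flipped positive: the condition runs at call time
```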

Anyway, this should cover most cases; if it doesn't, you should be more specific.

Answered 2021-05-06T09:44:34.513