2

为什么添加装饰器“@torch.jit.script”会导致错误,而我可以在该模块上调用 torch.jit.script,例如这会失败:

import torch
    
@torch.jit.script
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
    
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
"C:\Users\Administrator\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\jit\__init__.py", line 1262, in script
    raise RuntimeError("Type '{}' cannot be compiled since it inherits"
RuntimeError: Type '<class '__main__.MyCell'>' cannot be compiled since it inherits from nn.Module, pass an instance instead

虽然以下代码运行良好:

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
    
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
    
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

这个问题也出现在PyTorch 论坛上

4

1 回答 1

2

你的错误原因在这里,这个要点:

不支持继承或任何其他多态策略,除了从对象继承以指定新样式类。

此外,如顶部所述:

TorchScript 类支持是实验性的。目前它最适合简单的类似记录的类型(想想带有方法的 NamedTuple)。

目前,它的目的是用于简单的Python类(请参阅我提供的链接中的其他点)和函数,请参阅我提供的链接以获取更多信息。

您还可以检查torch.jit.script源代码以更好地了解它的工作原理。

从表面上看,当你传递一个实例时,所有attributes应该保留的都被递归解析(source)。您可以继续使用此功能(评论很多,但答案太长,请参见此处),尽管这种情况的确切原因(以及为什么以这种方式设计)超出了我的知识范围(因此希望有人在torch.jit' s 内部工作原理将更多地谈论它)。

于 2020-08-13T19:56:05.323 回答