为什么给类加上装饰器 `@torch.jit.script` 会报错，而直接对该模块的实例调用 `torch.jit.script` 却可以正常工作？例如，下面的代码会失败：
import torch
# NOTE: this decorator is what makes the snippet fail. torch.jit.script
# rejects a class that inherits from torch.nn.Module as soon as the decorator
# is applied (i.e. when the class statement executes), with:
#   RuntimeError: Type '<class '__main__.MyCell'>' cannot be compiled since it
#   inherits from nn.Module, pass an instance instead
# To compile the module, leave the class undecorated and script an instance,
# e.g. torch.jit.script(MyCell()).
@torch.jit.script
class MyCell(torch.nn.Module):
    """Toy cell computing ``tanh(linear(x) + h)`` (decorated only to show the error)."""

    def __init__(self):
        super(MyCell, self).__init__()
        # 4 -> 4 linear layer, so the last dimension of ``x`` must be 4.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        """Return ``(new_h, new_h)`` where ``new_h = tanh(linear(x) + h)``.

        ``h`` must be broadcast-compatible with ``linear(x)``.
        """
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# torch.jit.script compiles the module from its Python source and needs no
# example inputs. Its second positional parameter is the deprecated
# `optimize` flag, so passing (x, h) there would be ignored and only emit a
# warning. Despite its name, `traced_cell` holds a *scripted* module; tracing
# with example inputs would be torch.jit.trace(my_cell, (x, h)).
traced_cell = torch.jit.script(my_cell)
print(traced_cell)
# Run the compiled module once; the returned (new_h, new_h) tuple is unused here.
traced_cell(x, h)
  File "C:\Users\Administrator\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\jit\__init__.py", line 1262, in script
    raise RuntimeError("Type '{}' cannot be compiled since it inherits"
RuntimeError: Type '<class '__main__.MyCell'>' cannot be compiled since it inherits from nn.Module, pass an instance instead
而下面的代码却能正常运行：
class MyCell(torch.nn.Module):
    """Toy cell computing ``tanh(linear(x) + h)``.

    The class is deliberately left undecorated: ``torch.jit.script`` rejects a
    class that inherits from ``nn.Module`` ("pass an instance instead"), so an
    instance is compiled instead, e.g. ``torch.jit.script(MyCell())``.
    """

    def __init__(self) -> None:
        super().__init__()
        # 4 -> 4 linear layer, so the last dimension of ``x`` must be 4.
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor, h: torch.Tensor):
        """Return ``(new_h, new_h)`` where ``new_h = tanh(linear(x) + h)``.

        ``h`` must be broadcast-compatible with ``linear(x)``. The same tensor
        is returned twice (presumably mirroring an RNN cell's (output, state)
        pair -- TODO confirm intent).
        """
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# Script the *instance*, not the class. torch.jit.script needs no example
# inputs; its second positional parameter is the deprecated `optimize` flag,
# so passing (x, h) there would be ignored and only emit a warning. Despite its
# name, `traced_cell` holds a *scripted* module; tracing with example inputs
# would be torch.jit.trace(my_cell, (x, h)).
traced_cell = torch.jit.script(my_cell)
print(traced_cell)
# Run the compiled module once; the returned (new_h, new_h) tuple is unused here.
traced_cell(x, h)
此问题也已发布在 PyTorch 论坛上。