0

这段代码编译得很好:

import torch
import torch.nn             as nn

class Foo(nn.Module):
    def __init__(self):
        super(Foo, self).__init__()
        self.x = 0

    def forward(self, X):
        X      *= self.x
        self.x += 1
        return X

# @torch.jit.script
def bar(f: Foo):
    return f.x

但是,如果我取消注释该# @torch.jit.script行,我会收到此错误:

Traceback (most recent call last):
  File "test1.py", line 18, in <module>
    def bar(f: Foo):
  File "/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/jit/__init__.py", line 1103, in script
    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
RuntimeError:
Unknown type name 'Foo':
at test1.py:18:12
@torch.jit.script
def bar(f: Foo):
           ~~~ <--- HERE
    return f.x

如果我将类型注释更改为int

@torch.jit.script
# def bar(f: Foo):
#     return f.x
def bar(f: int):
    return f

然后编译再次工作。

有谁知道我需要做什么,以允许我的自定义类定义在类型注释中用于位于torch.jit.script装饰器下的函数的参数?

4

1 回答 1

1

只有此处文档中的类型列表可以用作函数的参数:

https://pytorch.org/docs/stable/jit_language_reference.html#supported-type

nn.Modules 在 TorchScript 中进行了一些特殊处理以使其工作,但目前不支持将它们作为参数。

于 2020-06-29T22:57:30.010 回答