0

我正在尝试通过脚本将 PyTorch 模型导出到 TorchScript,但我被卡住了。我创建了一个玩具类来展示这个问题:

import torch
from torch import nn


class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it thorugh a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.

    """
    def __init__(self, use_skip: bool):
        nn.Module.__init__(self)
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x

它基本上只包含一个线性层和一个可选的跳过连接。如果我尝试使用脚本编写模型

mod1 = SadModule(False)
scripted_mod1 = torch.jit.script(mod)

我收到以下错误:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-10-a7ebc7af32c7> in <module>
----> 1 scripted_mod1 = torch.jit.script(mod)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:

x_input is not defined in the false branch:
  File "<ipython-input-7-d08ed7ff42ec>", line 12
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-7-d08ed7ff42ec>", line 16
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x

因此,基本上 TorchScript 无法识别出任何一条语句mod1True分支if都不会被使用。此外,如果我们创建一个实际使用跳过连接的实例,

mod2 = SadModule(True)
scripted_mod2 = torch.jit.script(mod2)

我们会得到另一个错误:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-b5ca61d8aa73> in <module>
----> 1 scripted_mod2 = torch.jit.script(mod2)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:

x_input is not defined in the false branch:
  File "<ipython-input-18-ac8b9713c789>", line 17
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-18-ac8b9713c789>", line 21
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x

因此,在这种情况下,TorchScript 不理解两个ifs 都将始终为真,并且实际上x_input已明确定义。

为了避免这个问题,我可以将这个类分成两个子类,如下所示:

class SadModuleNoSkip(nn.Module):
    """Takes a (*, 2) input and runs it thorugh a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.

    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x = self.layer(x)
        return x

class SadModuleSkip(nn.Module):
    """Takes a (*, 2) input and runs it thorugh a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.

    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x_input = x
        x = self.layer(x)
        x = x + x_input
        return x

但是,我正在处理一个庞大的代码库,我必须为许多类重复该过程,这既费时又可能引入错误。此外,我正在处理的模块通常是巨大的卷积网络,而ifs 只是控制额外批量标准化的存在。在我看来,除了单个批次规范层之外,99% 的块中必须具有相同的类似乎是不可取的。

有没有办法可以帮助 TorchScript 处理分支?

编辑:添加了一个最小可行示例。

更新use_skip:即使我将提示输入为常量也不起作用

from typing import Final

class SadModule(nn.Module):
    use_skip: Final[bool]
    ...
4

1 回答 1

0

在 GitHub 上打开了一个问题。项目维护人员解释说,使用Final是要走的路。不过要小心,因为截至今天(2021 年 5 月 7 日),此功能仍在开发中(尽管处于最后阶段,请参阅此处了解功能跟踪器)。

尽管它还没有在官方版本中可用,但它存在于 PyTorch 的夜间版本中,因此您可以按照网站中的说明pytorch-nighly安装构建(向下滚动到Install PyTorch,然后选择Preview (Nightly),或者等待下一个版本。

对于几个月后阅读此答案的任何人,此功能应该已经集成到 PyTorch 的主要版本中。

于 2021-05-07T13:03:34.283 回答