I am trying to export a PyTorch model to TorchScript via scripting, and I am stuck. I've created a toy class to showcase the issue:
import torch
from torch import nn

class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.
    """
    def __init__(self, use_skip: bool):
        nn.Module.__init__(self)
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x
It basically consists of only a linear layer and an optional skip connection. If I try to script the model using
mod1 = SadModule(False)
scripted_mod1 = torch.jit.script(mod1)
I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-10-a7ebc7af32c7> in <module>
----> 1 scripted_mod1 = torch.jit.script(mod)
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
895
896 if isinstance(obj, torch.nn.Module):
--> 897 return torch.jit._recursive.create_script_module(
898 obj, torch.jit._recursive.infer_methods_to_compile
899 )
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
350 check_module_initialized(nn_module)
351 concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
353
354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
408 # Compile methods if necessary
409 if concrete_type not in concrete_type_store.methods_compiled:
--> 410 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
411 torch._C._run_emit_module_hook(cpp_module)
412 concrete_type_store.methods_compiled.add(concrete_type)
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
302 property_rcbs = [p.resolution_callback for p in property_stubs]
303
--> 304 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
305
306
RuntimeError:
x_input is not defined in the false branch:
  File "<ipython-input-7-d08ed7ff42ec>", line 12
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-7-d08ed7ff42ec>", line 16
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x
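So basically TorchScript isn't able to recognize that for mod1 the True branch of either if statement will never be used. Moreover, if we create an instance that actually uses the skip connection,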
mod2 = SadModule(True)
scripted_mod2 = torch.jit.script(mod2)
we get another error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-21-b5ca61d8aa73> in <module>
----> 1 scripted_mod2 = torch.jit.script(mod2)
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
895
896 if isinstance(obj, torch.nn.Module):
--> 897 return torch.jit._recursive.create_script_module(
898 obj, torch.jit._recursive.infer_methods_to_compile
899 )
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
350 check_module_initialized(nn_module)
351 concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352 return create_script_module_impl(nn_module, concrete_type, stubs_fn)
353
354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
408 # Compile methods if necessary
409 if concrete_type not in concrete_type_store.methods_compiled:
--> 410 create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
411 torch._C._run_emit_module_hook(cpp_module)
412 concrete_type_store.methods_compiled.add(concrete_type)
~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
302 property_rcbs = [p.resolution_callback for p in property_stubs]
303
--> 304 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
305
306
RuntimeError:
x_input is not defined in the false branch:
  File "<ipython-input-18-ac8b9713c789>", line 17
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-18-ac8b9713c789>", line 21
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x
So in this case TorchScript doesn't understand that both ifs will always be true and that x_input is in fact well defined.
To avoid this issue, I could split the class into two subclasses, as in:
class SadModuleNoSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.
    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x = self.layer(x)
        return x

class SadModuleSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.
    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x_input = x
        x = self.layer(x)
        x = x + x_input
        return x
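However, I am working on a huge code base and would have to repeat this process for many classes, which is time consuming and could introduce bugs. Moreover, the modules I'm dealing with are often huge convolutional nets, and the ifs just control the presence of an additional batch normalization. It seems undesirable to me to maintain classes whose blocks are 99% identical, differing only by a single batch norm layer.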
Is there a way I can help TorchScript with its handling of branches?
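For reference, the toy case itself can apparently be sidestepped without splitting the class by binding x_input unconditionally, so that it is defined on every path the compiler analyzes. This is only a minimal sketch under the assumption that keeping an extra reference to the input is acceptable; the class name HappyModule is hypothetical and not part of my code base, and the pattern may not carry over cleanly to every conditional in a larger network.

```python
import torch
from torch import nn


class HappyModule(nn.Module):  # hypothetical name, for illustration only
    """Linear layer over a (*, 2) input with an optional skip connection."""

    def __init__(self, use_skip: bool):
        super().__init__()
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Bind x_input unconditionally so it is defined in both branches
        # of the later `if`. This only stores a reference, not a copy.
        x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x


# Both configurations go through torch.jit.script with the same class.
scripted_no_skip = torch.jit.script(HappyModule(False))
scripted_skip = torch.jit.script(HappyModule(True))
```

The branch on self.use_skip is still present in the scripted graph here; the sketch only removes the "not defined in the false branch" error, it does not make TorchScript prune the unused branch.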
Edit: added a minimum viable example.
Update: It doesn't work even if I type hint use_skip as a constant:
from typing import Final

class SadModule(nn.Module):
    use_skip: Final[bool]
    ...
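For the optional batch normalization case mentioned above, a pattern that might avoid the branch in forward altogether is to choose the layer once in __init__ and fall back to nn.Identity when it is not wanted. The sketch below is an assumption about what such a block could look like; ConvBlock, channels, and use_extra_bn are hypothetical names and do not come from my actual code.

```python
import torch
from torch import nn


class ConvBlock(nn.Module):  # hypothetical block, for illustration only
    """Convolution followed by an optional batch normalization."""

    def __init__(self, channels: int, use_extra_bn: bool):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # The architectural choice is made here, in plain Python, before
        # scripting. nn.Identity returns its input unchanged, so forward
        # can call self.extra_bn without any conditional.
        self.extra_bn = nn.BatchNorm2d(channels) if use_extra_bn else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.extra_bn(self.conv(x))


scripted_plain = torch.jit.script(ConvBlock(4, use_extra_bn=False))
scripted_bn = torch.jit.script(ConvBlock(4, use_extra_bn=True))
```

This keeps a single class per block, at the cost of every optional layer needing a no-op stand-in with a compatible call signature.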