I use this code to convert my model to a scripted model:
scripted_model = torch.jit.trace(detector.model, images).eval()
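For completeness, the tracing step looks roughly like this in isolation (a minimal sketch; detector.model in my project is a DLA-based detection nn.Module, replaced here by a hypothetical stand-in, and the input shape is an assumption):

import torch
import torch.nn as nn

# Stand-in for detector.model, which in my project is a DLA-based detection nn.Module.
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
model.eval()  # inference mode before tracing

# Example input; the real images tensor uses the detector's expected shape (assumed here).
images = torch.randn(1, 3, 512, 512)

scripted_model = torch.jit.trace(model, images).eval()
print(scripted_model)        # module hierarchy, as shown below
print(scripted_model.graph)  # TorchScript graph, as shown further down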
Then I print scripted_model; part of the output looks like this:
(base): DLA(
  original_name=DLA
  (base_layer): Sequential(
    original_name=Sequential
    (0): Conv2d(original_name=Conv2d)
    (1): BatchNorm2d(original_name=BatchNorm2d)
    (2): ReLU(original_name=ReLU)
  )
  (level0): Sequential(
    original_name=Sequential
    (0): Conv2d(original_name=Conv2d)
    (1): BatchNorm2d(original_name=BatchNorm2d)
    (2): ReLU(original_name=ReLU)
  )
  (level1): Sequential(
    original_name=Sequential
    (0): Conv2d(original_name=Conv2d)
    (1): BatchNorm2d(original_name=BatchNorm2d)
    (2): ReLU(original_name=ReLU)
  )
  (level2): Tree(
    original_name=Tree
    (tree1): BasicBlock(
      original_name=BasicBlock
      (conv1): Conv2d(original_name=Conv2d)
      (bn1): BatchNorm2d(original_name=BatchNorm2d)
      (relu): ReLU(original_name=ReLU)
      (conv2): Conv2d(original_name=Conv2d)
      (bn2): BatchNorm2d(original_name=BatchNorm2d)
    )
    (tree2): BasicBlock(
      original_name=BasicBlock
      (conv1): Conv2d(original_name=Conv2d)
      (bn1): BatchNorm2d(original_name=BatchNorm2d)
      (relu): ReLU(original_name=ReLU)
      (conv2): Conv2d(original_name=Conv2d)
      (bn2): BatchNorm2d(original_name=BatchNorm2d)
    )
    (root): Root(
      original_name=Root
      (conv): Conv2d(original_name=Conv2d)
      (bn): BatchNorm2d(original_name=BatchNorm2d)
      (relu): ReLU(original_name=ReLU)
    )
    (downsample): MaxPool2d(original_name=MaxPool2d)
    (project): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
    )
  )
  ...
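I can already enumerate these submodules programmatically; as far as I can tell, named_modules() and the original_name attribute work on the traced module just like in the printout above (a minimal sketch):

# Enumerate the traced submodules; original_name is the class name shown in the printout above.
for name, submodule in scripted_model.named_modules():
    print(name, submodule.original_name)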
But what I actually want is the input size of each operator, for example the input size of the operator (0): Conv2d(original_name=Conv2d). I printed the graph of this scripted model; part of the output looks like this:
%4770 : __torch__.torch.nn.modules.module.___torch_mangle_11.Module = prim::GetAttr[name="wh"](%self.1)
%4762 : __torch__.torch.nn.modules.module.___torch_mangle_15.Module = prim::GetAttr[name="tracking"](%self.1)
%4754 : __torch__.torch.nn.modules.module.___torch_mangle_23.Module = prim::GetAttr[name="rot"](%self.1)
%4746 : __torch__.torch.nn.modules.module.___torch_mangle_7.Module = prim::GetAttr[name="reg"](%self.1)
%4738 : __torch__.torch.nn.modules.module.___torch_mangle_3.Module = prim::GetAttr[name="hm"](%self.1)
%4730 : __torch__.torch.nn.modules.module.___torch_mangle_27.Module = prim::GetAttr[name="dim"](%self.1)
%4722 : __torch__.torch.nn.modules.module.___torch_mangle_19.Module = prim::GetAttr[name="dep"](%self.1)
%4714 : __torch__.torch.nn.modules.module.___torch_mangle_31.Module = prim::GetAttr[name="amodel_offset"](%self.1)
%4706 : __torch__.torch.nn.modules.module.___torch_mangle_289.Module = prim::GetAttr[name="ida_up"](%self.1)
%4645 : __torch__.torch.nn.modules.module.___torch_mangle_262.Module = prim::GetAttr[name="dla_up"](%self.1)
%4461 : __torch__.torch.nn.modules.module.___torch_mangle_180.Module = prim::GetAttr[name="base"](%self.1)
%5100 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4461, %input.1)
%5082 : Tensor, %5083 : Tensor, %5084 : Tensor, %5085 : Tensor, %5086 : Tensor, %5087 : Tensor, %5088 : Tensor, %5089 : Tensor = prim::TupleUnpack(%5100)
%5101 : (Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4645, %5082, %5083, %5084, %5085, %5086, %5087, %5088, %5089)
%5097 : Tensor, %5098 : Tensor, %5099 : Tensor = prim::TupleUnpack(%5101)
%3158 : None = prim::Constant()
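For reference, this is roughly how I walk this graph to get the operator names, and where I would expect to read the input sizes if they were recorded (a minimal sketch; node.kind(), node.inputs() and TensorType.sizes() are the graph-inspection calls I am aware of, and sizes() may simply come back as None because the traced graph apparently does not keep the shapes):

# Walk the nodes (operators) of the traced graph.
# scripted_model.inlined_graph can be walked the same way; it also inlines the submodule
# forward() calls, so the aten:: operators inside base, dla_up, ida_up, ... show up as well.
for node in scripted_model.graph.nodes():
    print(node.kind())  # e.g. prim::GetAttr, prim::CallMethod, aten::_convolution
    for inp in node.inputs():
        t = inp.type()
        if isinstance(t, torch._C.TensorType):
            # sizes() returns the recorded shape, or None when the graph keeps no shape info.
            print("  input:", inp.debugName(), t.sizes())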
So I can even find the operators' names. How can I get the input size of a specific operator in the scripted model?