如何准确找到 PyTorch 模型图(torch._C.Graph)中存在的节点,以及这些节点的输入?
我尝试使用以下代码:
# Compile the model (switched to eval mode first) to TorchScript, freeze the
# scripted module, then print its inlined IR graph. In the printed graph the
# module's weights/biases appear as prim::Constant nodes.
model = MyModel().eval()
scripted = torch.jit.script(model)
frozen_module = torch.jit.freeze(scripted)
print(frozen_module.inlined_graph)
运行后给出了以下输出:
graph(%self : __torch__.___torch_mangle_2.MyModel,
%x1.1 : Tensor,
%x2.1 : Tensor,
%x3.1 : Tensor):
%4 : Float(52229:1, 4:52229, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
%5 : Float(10:1, 5:10, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
%6 : int[] = prim::Constant[value=[0, 0]]()
%7 : int[] = prim::Constant[value=[2, 2]]()
%8 : int[] = prim::Constant[value=[1, 1]]()
%9 : int = prim::Constant[value=2]()
%10 : bool = prim::Constant[value=0]()
%11 : int = prim::Constant[value=1]() # test.py:39:34
%12 : int = prim::Constant[value=0]() # test.py:39:29
%13 : int = prim::Constant[value=-1]() # test.py:39:33
%self.classifier.bias : Float(4:1, requires_grad=0, device=cpu) = prim::Constant[value=0.001 * 2.8424 1.0601 -1.3229 4.2920 [ CPUFloatType{4} ]]()
%self.features3.0.bias : Float(5:1, requires_grad=0, device=cpu) = prim::Constant[value= 0.0111 -0.0702 0.1396 0.1691 0.1335 [ CPUFloatType{5} ]]()
%self.features2.0.bias : Float(3:1, requires_grad=0, device=cpu) = prim::Constant[value= 0.3314 0.0165 0.2588 [ CPUFloatType{3} ]]()
%self.features2.0.weight : Float(3:9, 1:9, 3:3, 3:1, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
%self.features1.0.bias : Float(3:1, requires_grad=0, device=cpu) = prim::Constant[value=0.01 * 2.5380 -31.8947 -15.3462 [ CPUFloatType{3} ]]()
%self.features1.0.weight : Float(3:9, 1:9, 3:3, 3:1, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
%input.4 : Tensor = aten::conv2d(%x1.1, %self.features1.0.weight, %self.features1.0.bias, %8, %8, %8, %11)
%input.6 : Tensor = aten::max_pool2d(%input.4, %7, %7, %6, %8, %10)
%x1.3 : Tensor = aten::relu(%input.6)
%input.7 : Tensor = aten::conv2d(%x2.1, %self.features2.0.weight, %self.features2.0.bias, %8, %8, %8, %11)
%input.8 : Tensor = aten::max_pool2d(%input.7, %7, %7, %6, %8, %10)
%x2.3 : Tensor = aten::relu(%input.8)
%26 : int = aten::dim(%x3.1)
%27 : bool = aten::eq(%26, %9)
%input.3 : Tensor = prim::If(%27)
block0():
%ret.2 : Tensor = aten::addmm(%self.features3.0.bias, %x3.1, %5, %11, %11)
-> (%ret.2)
block1():
%output.2 : Tensor = aten::matmul(%x3.1, %5)
%output.4 : Tensor = aten::add_(%output.2, %self.features3.0.bias, %11)
-> (%output.4)
%x3.3 : Tensor = aten::relu(%input.3)
%33 : int = aten::size(%x1.3, %12)
%34 : int[] = prim::ListConstruct(%33, %13)
%x1.6 : Tensor = aten::view(%x1.3, %34)
%36 : int = aten::size(%x2.3, %12)
%37 : int[] = prim::ListConstruct(%36, %13)
%x2.6 : Tensor = aten::view(%x2.3, %37)
%39 : int = aten::size(%x3.3, %12)
%40 : int[] = prim::ListConstruct(%39, %13)
%x3.6 : Tensor = aten::view(%x3.3, %40)
%42 : Tensor[] = prim::ListConstruct(%x1.6, %x2.6, %x3.6)
%x.1 : Tensor = aten::cat(%42, %11)
%44 : int = aten::dim(%x.1)
%45 : bool = aten::eq(%44, %9)
%x.3 : Tensor = prim::If(%45)
block0():
%ret.1 : Tensor = aten::addmm(%self.classifier.bias, %x.1, %4, %11, %11)
-> (%ret.1)
block1():
%output.1 : Tensor = aten::matmul(%x.1, %4)
%output.3 : Tensor = aten::add_(%output.1, %self.classifier.bias, %11)
-> (%output.3)
return (%x.3)
但是我无法遍历这个图,也无法准确获取其中的节点以及每个节点的输入。请问是否有其他方法可以实现上述操作?