
How can I find exactly which nodes exist in a PyTorch model graph, and what their inputs are? I tried to get a torch._C.Graph using:


import torch

# MyModel is my own nn.Module (definition not shown here)
scripted = torch.jit.script(MyModel().eval())
frozen_module = torch.jit.freeze(scripted)
print(frozen_module.inlined_graph)

which gave the following output:

graph(%self : __torch__.___torch_mangle_2.MyModel,
      %x1.1 : Tensor,
      %x2.1 : Tensor,
      %x3.1 : Tensor):
  %4 : Float(52229:1, 4:52229, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %5 : Float(10:1, 5:10, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %6 : int[] = prim::Constant[value=[0, 0]]()
  %7 : int[] = prim::Constant[value=[2, 2]]()
  %8 : int[] = prim::Constant[value=[1, 1]]()
  %9 : int = prim::Constant[value=2]() 
  %10 : bool = prim::Constant[value=0]()
  %11 : int = prim::Constant[value=1]() # test.py:39:34
  %12 : int = prim::Constant[value=0]() # test.py:39:29
  %13 : int = prim::Constant[value=-1]() # test.py:39:33
  %self.classifier.bias : Float(4:1, requires_grad=0, device=cpu) = prim::Constant[value=0.001 *  2.8424  1.0601 -1.3229  4.2920 [ CPUFloatType{4} ]]()
  %self.features3.0.bias : Float(5:1, requires_grad=0, device=cpu) = prim::Constant[value= 0.0111 -0.0702  0.1396  0.1691  0.1335 [ CPUFloatType{5} ]]()
  %self.features2.0.bias : Float(3:1, requires_grad=0, device=cpu) = prim::Constant[value= 0.3314  0.0165  0.2588 [ CPUFloatType{3} ]]()
  %self.features2.0.weight : Float(3:9, 1:9, 3:3, 3:1, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %self.features1.0.bias : Float(3:1, requires_grad=0, device=cpu) = prim::Constant[value=0.01 *  2.5380 -31.8947 -15.3462 [ CPUFloatType{3} ]]()
  %self.features1.0.weight : Float(3:9, 1:9, 3:3, 3:1, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %input.4 : Tensor = aten::conv2d(%x1.1, %self.features1.0.weight, %self.features1.0.bias, %8, %8, %8, %11) 
  %input.6 : Tensor = aten::max_pool2d(%input.4, %7, %7, %6, %8, %10) 
  %x1.3 : Tensor = aten::relu(%input.6) 
  %input.7 : Tensor = aten::conv2d(%x2.1, %self.features2.0.weight, %self.features2.0.bias, %8, %8, %8, %11) 
  %input.8 : Tensor = aten::max_pool2d(%input.7, %7, %7, %6, %8, %10) 
  %x2.3 : Tensor = aten::relu(%input.8) 
  %26 : int = aten::dim(%x3.1) 
  %27 : bool = aten::eq(%26, %9) 
  %input.3 : Tensor = prim::If(%27) 
    block0():
      %ret.2 : Tensor = aten::addmm(%self.features3.0.bias, %x3.1, %5, %11, %11) 
      -> (%ret.2)
    block1():
      %output.2 : Tensor = aten::matmul(%x3.1, %5) 
      %output.4 : Tensor = aten::add_(%output.2, %self.features3.0.bias, %11) 
      -> (%output.4)
  %x3.3 : Tensor = aten::relu(%input.3) 
  %33 : int = aten::size(%x1.3, %12) 
  %34 : int[] = prim::ListConstruct(%33, %13)
  %x1.6 : Tensor = aten::view(%x1.3, %34) 
  %36 : int = aten::size(%x2.3, %12) 
  %37 : int[] = prim::ListConstruct(%36, %13)
  %x2.6 : Tensor = aten::view(%x2.3, %37) 
  %39 : int = aten::size(%x3.3, %12) 
  %40 : int[] = prim::ListConstruct(%39, %13)
  %x3.6 : Tensor = aten::view(%x3.3, %40)
  %42 : Tensor[] = prim::ListConstruct(%x1.6, %x2.6, %x3.6)
  %x.1 : Tensor = aten::cat(%42, %11) 
  %44 : int = aten::dim(%x.1)
  %45 : bool = aten::eq(%44, %9) 
  %x.3 : Tensor = prim::If(%45) 
    block0():
      %ret.1 : Tensor = aten::addmm(%self.classifier.bias, %x.1, %4, %11, %11) 
      -> (%ret.1)
    block1():
      %output.1 : Tensor = aten::matmul(%x.1, %4) 
      %output.3 : Tensor = aten::add_(%output.1, %self.classifier.bias, %11) 
      -> (%output.3)
  return (%x.3)

However, I can't iterate over the nodes in this graph or find out exactly what inputs each node takes. Is there another way to do this?
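A minimal sketch of one way to do this, assuming a recent PyTorch where the internal torch._C bindings expose Graph.nodes(), Graph.inputs(), Node.kind(), Node.inputs(), Node.outputs(), Node.blocks(), Value.debugName(), Value.node() and Value.toIValue(). These bindings are not part of the documented public API and may change between releases. TinyModel is a hypothetical stand-in for MyModel, and walk/describe/dump are illustrative helper names. The printed graph is only the string form of a torch._C.Graph; the object itself can be iterated with nodes(). Nodes under prim::If live in sub-blocks, so walk recurses into node.blocks().

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for the question's MyModel, whose definition is not shown.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(27, 2)  # 3 channels * 3 * 3 for a 5x5 input

    def forward(self, x):
        # Branching on x.dim() produces a prim::If node with nested blocks,
        # similar to the ones in the question's graph.
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = torch.relu(self.conv(x))
        x = x.view(x.size(0), -1)
        return self.fc(x)


def walk(block, depth=0):
    """Yield (depth, node) for every node, recursing into sub-blocks.

    `block` may be a torch._C.Graph or a torch._C.Block; both expose nodes().
    """
    for node in block.nodes():
        yield depth, node
        for sub in node.blocks():  # e.g. block0/block1 of a prim::If
            yield from walk(sub, depth + 1)


def describe(value):
    """Return '%name', plus the constant payload when the value comes from prim::Constant."""
    name = "%" + value.debugName()
    if value.node().kind() == "prim::Constant":
        payload = value.toIValue()  # Python value of the constant (int list, bool, Tensor, ...)
        if isinstance(payload, torch.Tensor):
            return f"{name}=<Tensor shape={tuple(payload.shape)}>"
        return f"{name}={payload!r}"
    return name


def dump(graph):
    """Print each node as 'outputs = kind(inputs)', indented by block depth."""
    for depth, node in walk(graph):
        outs = ", ".join("%" + v.debugName() for v in node.outputs())
        ins = ", ".join(describe(v) for v in node.inputs())
        print(f"{'  ' * depth}{outs} = {node.kind()}({ins})")


frozen = torch.jit.freeze(torch.jit.script(TinyModel().eval()))
graph = frozen.inlined_graph  # a torch._C.Graph, like the one printed in the question
print("graph inputs:", [v.debugName() for v in graph.inputs()])
dump(graph)
```

Because walk yields Node objects rather than strings, the same traversal can be used for filtering, e.g. [n for _, n in walk(graph) if n.kind() == "aten::conv2d"], and value.node() follows any input back to the node that produced it.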
