就我而言,PyTorch 中没有 aSplitTable
或 a之类的东西。SelectTable
也就是说,您可以在单个架构中连接任意数量的模块或块,并且您可以使用此属性来检索某个层的输出。让我们用一个简单的例子更清楚地说明这一点。
假设我想构建一个简单的两层 MLP 并检索每一层的输出。我可以构建一个自定义class
继承自nn.Module
:
class MyMLP(nn.Module):
def __init__(self, in_channels, out_channels_1, out_channels_2):
# first of all, calling base class constructor
super().__init__()
# now I can build my modular network
self.block1 = nn.Linear(in_channels, out_channels_1)
self.block2 = nn.Linear(out_channels_1, out_channels_2)
# you MUST implement a forward(input) method whenever inheriting from nn.Module
def forward(x):
# first_out will now be your output of the first block
first_out = self.block1(x)
x = self.block2(first_out)
# by returning both x and first_out, you can now access the first layer's output
return x, first_out
在您的主文件中,您现在可以声明自定义架构并使用它:
from myFile import MyMLP
import numpy as np
in_ch = out_ch_1 = out_ch_2 = 64
# some fake input instance
x = np.random.rand(in_ch)
my_mlp = MyMLP(in_ch, out_ch_1, out_ch_2)
# get your outputs
final_out, first_layer_out = my_mlp(x)
此外,您可以在更复杂的模型定义中连接两个 MyMLP,并以类似的方式检索每个 MyMLP 的输出。我希望这足以澄清,但如果您有更多问题,请随时提问,因为我可能遗漏了一些东西。