7

当我在 PyTorch 中使用预定义模块时,我通常可以很容易地访问它的权重。但是,如果我先包装模块,如何访问它们nn.Sequential()?rg:

class My_Model_1(nn.Module):
    def __init__(self,D_in,D_out):
        super(My_Model_1, self).__init__()
        self.layer = nn.Linear(D_in,D_out)
    def forward(self,x):
        out = self.layer(x)
        return out

class My_Model_2(nn.Module):
    def __init__(self,D_in,D_out):
        super(My_Model_2, self).__init__()
        self.layer = nn.Sequential(nn.Linear(D_in,D_out))
    def forward(self,x):
        out = self.layer(x)
        return out

model_1 = My_Model_1(10,10)
print(model_1.layer.weight)
model_2 = My_Model_2(10,10)

我现在如何打印权重? model_2.layer.0.weight不起作用。

4

3 回答 3

12

访问权重的一种简单方法是使用state_dict()模型的 。

这应该适用于您的情况:

for k, v in model_2.state_dict().iteritems():
    print("Layer {}".format(k))
    print(v)

另一种选择是获取modules()迭代器。如果您事先知道图层的类型,这也应该有效:

for layer in model_2.modules():
   if isinstance(layer, nn.Linear):
        print(layer.weight)
于 2017-06-01T19:00:50.730 回答
10

PyTorch 论坛,这是推荐的方式:

model_2.layer[0].weight
于 2017-06-04T10:21:34.143 回答
0

您可以使用以下名称访问模块_modules

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(3, 3, 3)

    def forward(self, input):
        return self.conv1(input)

model = Net()
print(model._modules['conv1'])
于 2019-10-24T21:46:52.327 回答