1

我正在尝试删除高效网络-pytorch 实现中的顶层。但是,如果我只是用我自己的全连接层替换最后_fc一层,正如作者在这个 github 评论中所建议的那样,我担心swish即使在这一层之后仍然有激活,而不是像我预期的那样没有任何东西。当我打印模型时,最后几行如下:

(_bn1): BatchNorm2d(1280, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_avg_pooling): AdaptiveAvgPool2d(output_size=1)
    (_dropout): Dropout(p=0.2, inplace=False)
    (_fc): Sequential(
      (0): Linear(in_features=1280, out_features=512, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.25, inplace=False)
      (3): Linear(in_features=512, out_features=128, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.25, inplace=False)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
    (_swish): MemoryEfficientSwish()
  )
)

_fc我的替换模块在哪里。

我希望做的是:

base_model = EfficientNet.from_pretrained('efficientnet-b3')
model = nn.Sequential(*list(base_model.children()[:-3]))

在我看来base_model.children(),从嵌套结构中将模型展平。但是,现在我似乎无法像使用虚拟输入一样使用模型,x=torch.randn(1,3,255,255)我收到错误:TypeError: forward() takes 1 positional argument but 2 were given.

应该注意的是,model[:2](x)有效,但不是model[:3](x)model[2]似乎是移动块。

这是一个带有上述代码的colab 笔记本。

4

1 回答 1

3

这是对print(net)实际操作的常见误解。

_swish后面有一个模块的事实_fc仅仅意味着前者是在后者之后注册的。您可以在代码中检查:

class EfficientNet(nn.Module):
    def __init__(self, blocks_args=None, global_params=None):

        # [...]

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

它们的定义顺序是它们将被打印的顺序。当涉及到具体执行的内容时,您必须检查forward

def forward(self, inputs):
    # Convolution layers
    x = self.extract_features(inputs)

    # Pooling and final linear layer
    x = self._avg_pooling(x)
    x = x.flatten(start_dim=1)
    x = self._dropout(x)
    x = self._fc(x)

    return x

而且,如您所见, 之后没有任何内容self._fc(x),这意味着不会Swish应用。

于 2020-07-18T13:41:38.470 回答