我正在尝试删除高效网络-pytorch 实现中的顶层。但是,如果我只是用我自己的全连接层替换最后_fc
一层,正如作者在这个 github 评论中所建议的那样,我担心swish
即使在这一层之后仍然有激活,而不是像我预期的那样没有任何东西。当我打印模型时,最后几行如下:
(_bn1): BatchNorm2d(1280, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
(_avg_pooling): AdaptiveAvgPool2d(output_size=1)
(_dropout): Dropout(p=0.2, inplace=False)
(_fc): Sequential(
(0): Linear(in_features=1280, out_features=512, bias=True)
(1): ReLU()
(2): Dropout(p=0.25, inplace=False)
(3): Linear(in_features=512, out_features=128, bias=True)
(4): ReLU()
(5): Dropout(p=0.25, inplace=False)
(6): Linear(in_features=128, out_features=1, bias=True)
)
(_swish): MemoryEfficientSwish()
)
)
_fc
我的替换模块在哪里。
我希望做的是:
base_model = EfficientNet.from_pretrained('efficientnet-b3')
model = nn.Sequential(*list(base_model.children()[:-3]))
在我看来base_model.children()
,从嵌套结构中将模型展平。但是,现在我似乎无法像使用虚拟输入一样使用模型,x=torch.randn(1,3,255,255)
我收到错误:TypeError: forward() takes 1 positional argument but 2 were given
.
应该注意的是,model[:2](x)
有效,但不是model[:3](x)
。model[2]
似乎是移动块。
这是一个带有上述代码的colab 笔记本。