这是我的模型(float32)的一部分,我打算对它进行模块融合(fuse)以便做量化。我的做法是用 named_modules 遍历每个子模块,并检查它是 Conv2d、BatchNorm 还是 ReLU。
(scratch): Module(
(layer1_rn): Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(layer2_rn): Conv2d(40, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(layer3_rn): Conv2d(112, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(layer4_rn): Conv2d(320, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(activation): ReLU()
(refinenet4): FeatureFusionBlock_custom(
(out_conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
(resConfUnit1): ResidualConvUnit_custom(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(activation): ReLU()
(skip_add): FloatFunctional(
(activation_post_process): Identity()
)
)
(resConfUnit2): ResidualConvUnit_custom(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(activation): ReLU()
(skip_add): FloatFunctional(
(activation_post_process): Identity()
)
)
(skip_add): FloatFunctional(
(activation_post_process): Identity()
)
)
# List every submodule of `m` under every attribute path it is registered at.
#
# By default nn.Module.named_modules() uses remove_duplicate=True: it yields
# each module *object* only once, under the first path where it is found.
# If one nn.ReLU instance is registered in several places (presumably the
# `activation` passed into the FeatureFusionBlock / ResidualConvUnit
# constructors is the same object as `scratch.activation` -- verify with
# `m.scratch.activation is m.scratch.refinenet4.resConfUnit1.activation`),
# its later paths such as `scratch.refinenet4.resConfUnit1.activation` are
# silently skipped. That is not a PyTorch bug; it is the documented default.
# remove_duplicate=False reports every path, which is what a scan for
# Conv2d/BatchNorm/ReLU fusion patterns needs.
# NOTE(review): the `remove_duplicate` keyword requires a reasonably recent
# PyTorch release; older versions raise TypeError here -- confirm your version.
for name, module in m.named_modules(remove_duplicate=False):
    print(name, module)
我发现它漏掉了 scratch.refinenet4.resConfUnit1.activation,resConfUnit2 中的 activation 也同样被漏掉了(见下面输出中标 > 的部分)。
这是 PyTorch 的 bug 吗?
scratch.layer1_rn Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
scratch.layer2_rn Conv2d(40, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
scratch.layer3_rn Conv2d(112, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
scratch.layer4_rn Conv2d(320, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
scratch.activation ReLU()
here
scratch.refinenet4 FeatureFusionBlock_custom(
(out_conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
(resConfUnit1): ResidualConvUnit_custom(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(activation): ReLU()
(skip_add): FloatFunctional(
(activation_post_process): Identity()
)
)
(resConfUnit2): ResidualConvUnit_custom(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(activation): ReLU()
(skip_add): FloatFunctional(
(activation_post_process): Identity()
)
)
(skip_add): FloatFunctional(
(activation_post_process): Identity()
)
)
scratch.refinenet4.out_conv Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
scratch.refinenet4.resConfUnit1 ResidualConvUnit_custom(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(activation): ReLU()
(skip_add): FloatFunctional(
(activation_post_process): Identity()
)
)
> scratch.refinenet4.resConfUnit1.conv1 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
> scratch.refinenet4.resConfUnit1.conv2 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
> scratch.refinenet4.resConfUnit1.skip_add FloatFunctional(
> (activation_post_process): Identity()
> )
scratch.refinenet4.resConfUnit1.skip_add.activation_post_process Identity()
scratch.refinenet4.resConfUnit2 ResidualConvUnit_custom(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(activation): ReLU()
(skip_add): FloatFunctional(
(activation_post_process): Identity()
)
)
scratch.refinenet4.resConfUnit2.conv1 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
scratch.refinenet4.resConfUnit2.conv2 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
scratch.refinenet4.resConfUnit2.skip_add FloatFunctional(
(activation_post_process): Identity()
)
scratch.refinenet4.resConfUnit2.skip_add.activation_post_process Identity()
scratch.refinenet4.skip_add FloatFunctional(
(activation_post_process): Identity()
)
...