
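I am trying to load a pretrained resnet56 checkpoint into a slightly different model. The pretrained model was built as a plain ResNet, whereas the model I want to load it into splits each stage into two parts: the first part is a Sequential, and the remaining blocks are put into a list. I register those remaining blocks under the same names a normal ResNet would use. This is how I build the network: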

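# Fragment from the model's __init__.
# The first blocks of the stage go into a Sequential; the remaining blocks
# are collected in self.feature_layer_2.
self.layer2 = self._make_layer(block, 32, blocks=n_size, stride=2,
                                extract_feature=extract_feature,
                                feature_layer=self.feature_layer_2)
# Register the remaining blocks under dotted names such as 'layer2.5',
# so that the parameter names mirror those of a plain ResNet.
for index, item in enumerate(self.feature_layer_2):
    name_num = self.n_size - extract_feature + index
    setattr(self, f'layer2.{name_num}', item)
# Same scheme for the third stage.
self.layer3 = self._make_layer(block, 64, blocks=n_size, stride=2,
                                extract_feature=extract_feature,
                                feature_layer=self.feature_layer_3)
for index, item in enumerate(self.feature_layer_3):
    name_num = self.n_size - extract_feature + index
    setattr(self, f'layer3.{name_num}', item)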

Now I get this unexpected-key error when loading the state_dict:

RuntimeError: Error(s) in loading state_dict for ResNetc: Unexpected key(s) in state_dict: "layer2.5.conv1.weight", "layer2.5.bn1.weight", "layer2.5.bn1.bias", "layer2.5.bn1.running_mean", "layer2.5.bn1.running_var", "layer2.5.bn1.num_batches_tracked", "layer2.5.conv2.weight", "layer2.5.bn2.weight", "layer2.5.bn2.bias", "layer2.5.bn2.running_mean", "layer2.5.bn2.running_var", "layer2.5.bn2.num_batches_tracked", "layer2.6.conv1.weight", "layer2.6.bn1.weight", "layer2.6.bn1.bias", "layer2.6.bn1.running_mean", "layer2.6.bn1.running_var", "layer2.6.bn1.num_batches_tracked", "layer2.6.conv2.weight", "layer2.6.bn2.weight", "layer2.6.bn2.bias", "layer2.6.bn2.running_mean", "layer2.6.bn2.running_var", "layer2.6.bn2.num_batches_tracked", "layer2.7.conv1.weight", "layer2.7.bn1.weight", "layer2.7.bn1.bias", "layer2.7.bn1.running_mean", "layer2.7.bn1.running_var", "layer2.7.bn1.num_batches_tracked", "layer2.7.conv2.weight", "layer2.7.bn2.weight", "layer2.7.bn2.bias", "layer2.7.bn2.running_mean", "layer2.7.bn2.running_var", "layer2.7.bn2.num_batches_tracked", "layer2.8.conv1.weight", "layer2.8.bn1.weight", "layer2.8.bn1.bias", "layer2.8.bn1.running_mean", "layer2.8.bn1.running_var", "layer2.8.bn1.num_batches_tracked", "layer2.8.conv2.weight", "layer2.8.bn2.weight", "layer2.8.bn2.bias", "layer2.8.bn2.running_mean", "layer2.8.bn2.running_var", "layer2.8.bn2.num_batches_tracked", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn1.num_batches_tracked", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.bn2.num_batches_tracked", "layer3.6.conv1.weight", "layer3.6.bn1.weight", "layer3.6.bn1.bias", "layer3.6.bn1.running_mean", "layer3.6.bn1.running_var", "layer3.6.bn1.num_batches_tracked", "layer3.6.conv2.weight", "layer3.6.bn2.weight", "layer3.6.bn2.bias", "layer3.6.bn2.running_mean", "layer3.6.bn2.running_var", "layer3.6.bn2.num_batches_tracked", "layer3.7.conv1.weight", "layer3.7.bn1.weight", "layer3.7.bn1.bias", "layer3.7.bn1.running_mean", "layer3.7.bn1.running_var", "layer3.7.bn1.num_batches_tracked", "layer3.7.conv2.weight", "layer3.7.bn2.weight", "layer3.7.bn2.bias", "layer3.7.bn2.running_mean", "layer3.7.bn2.running_var", "layer3.7.bn2.num_batches_tracked", "layer3.8.conv1.weight", "layer3.8.bn1.weight", "layer3.8.bn1.bias", "layer3.8.bn1.running_mean", "layer3.8.bn1.running_var", "layer3.8.bn1.num_batches_tracked", "layer3.8.conv2.weight", "layer3.8.bn2.weight", "layer3.8.bn2.bias", "layer3.8.bn2.running_mean", "layer3.8.bn2.running_var", "layer3.8.bn2.num_batches_tracked".

I also tried to verify that the keys of the two state_dicts are the same, using the following code:

print('print keys from state_dict while not in model---------------------------------------')
for item in state.keys():
    if item not in model.state_dict().keys():
        print(item)
print(f'Length of keys in state_dict: {len(state.keys())}')
print(f'Length of keys in model: {len(model.state_dict().keys())}')

This is the output I got:

print keys from state_dict while not in model---------------------------------------
Length of keys in state_dict: 344
Length of keys in model: 344
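
The check above only tests one direction (checkpoint keys that are missing from the model). For completeness, here is a minimal sketch of a two-way comparison. It uses plain dicts as stand-ins for `state` and `model.state_dict()` so that it runs without PyTorch, and `diff_keys` is a hypothetical helper name, not part of any library:

```python
def diff_keys(checkpoint_keys, model_keys):
    """Return (only_in_checkpoint, only_in_model) as sorted lists."""
    checkpoint_keys = set(checkpoint_keys)
    model_keys = set(model_keys)
    only_in_checkpoint = sorted(checkpoint_keys - model_keys)
    only_in_model = sorted(model_keys - checkpoint_keys)
    return only_in_checkpoint, only_in_model


# Toy stand-ins for state.keys() and model.state_dict().keys()
checkpoint = {"layer2.5.conv1.weight": 0, "fc.weight": 0}
model_state = {"layer2.5.conv1.weight": 0, "fc.bias": 0}

extra, missing = diff_keys(checkpoint.keys(), model_state.keys())
print("only in checkpoint:", extra)
print("only in model:", missing)
```

With the real objects this would be called as `diff_keys(state.keys(), model.state_dict().keys())`. Since both sides report 344 keys and the loop above prints nothing, both lists should come back empty, which means the key names themselves match even though `load_state_dict` still reports them as unexpected.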
