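I am trying to load a pretrained resnet56 checkpoint into a slightly different model. The pretrained model was built as a normal ResNet, while the model I want to load it into splits each stage into two parts: the first part is a Sequential, and the remaining blocks are kept in a list. I register the list items under the same names a normal ResNet would use:
Building the network: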
# First part of stage 2 is a Sequential; the remaining blocks go into self.feature_layer_2.
self.layer2 = self._make_layer(block, 32, blocks=n_size, stride=2,
                               extract_feature=extract_feature,
                               feature_layer=self.feature_layer_2)
# Register the remaining blocks under the names a normal ResNet would use (layer2.<index>).
for index, item in enumerate(self.feature_layer_2):
    name_num = self.n_size - extract_feature + index
    setattr(self, f'layer2.{name_num}', item)
# Same split and naming for stage 3.
self.layer3 = self._make_layer(block, 64, blocks=n_size, stride=2,
                               extract_feature=extract_feature,
                               feature_layer=self.feature_layer_3)
for index, item in enumerate(self.feature_layer_3):
    name_num = self.n_size - extract_feature + index
    setattr(self, f'layer3.{name_num}', item)
Now I get this unexpected-key error:
Error message:
RuntimeError: Error(s) in loading state_dict for ResNetc:
    Unexpected key(s) in state_dict:
    "layer2.5.conv1.weight", "layer2.5.bn1.weight", "layer2.5.bn1.bias", "layer2.5.bn1.running_mean", "layer2.5.bn1.running_var", "layer2.5.bn1.num_batches_tracked", "layer2.5.conv2.weight", "layer2.5.bn2.weight", "layer2.5.bn2.bias", "layer2.5.bn2.running_mean", "layer2.5.bn2.running_var", "layer2.5.bn2.num_batches_tracked",
    "layer2.6.conv1.weight", "layer2.6.bn1.weight", "layer2.6.bn1.bias", "layer2.6.bn1.running_mean", "layer2.6.bn1.running_var", "layer2.6.bn1.num_batches_tracked", "layer2.6.conv2.weight", "layer2.6.bn2.weight", "layer2.6.bn2.bias", "layer2.6.bn2.running_mean", "layer2.6.bn2.running_var", "layer2.6.bn2.num_batches_tracked",
    "layer2.7.conv1.weight", "layer2.7.bn1.weight", "layer2.7.bn1.bias", "layer2.7.bn1.running_mean", "layer2.7.bn1.running_var", "layer2.7.bn1.num_batches_tracked", "layer2.7.conv2.weight", "layer2.7.bn2.weight", "layer2.7.bn2.bias", "layer2.7.bn2.running_mean", "layer2.7.bn2.running_var", "layer2.7.bn2.num_batches_tracked",
    "layer2.8.conv1.weight", "layer2.8.bn1.weight", "layer2.8.bn1.bias", "layer2.8.bn1.running_mean", "layer2.8.bn1.running_var", "layer2.8.bn1.num_batches_tracked", "layer2.8.conv2.weight", "layer2.8.bn2.weight", "layer2.8.bn2.bias", "layer2.8.bn2.running_mean", "layer2.8.bn2.running_var", "layer2.8.bn2.num_batches_tracked",
    "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn1.num_batches_tracked", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.bn2.num_batches_tracked",
    "layer3.6.conv1.weight", "layer3.6.bn1.weight", "layer3.6.bn1.bias", "layer3.6.bn1.running_mean", "layer3.6.bn1.running_var", "layer3.6.bn1.num_batches_tracked", "layer3.6.conv2.weight", "layer3.6.bn2.weight", "layer3.6.bn2.bias", "layer3.6.bn2.running_mean", "layer3.6.bn2.running_var", "layer3.6.bn2.num_batches_tracked",
    "layer3.7.conv1.weight", "layer3.7.bn1.weight", "layer3.7.bn1.bias", "layer3.7.bn1.running_mean", "layer3.7.bn1.running_var", "layer3.7.bn1.num_batches_tracked", "layer3.7.conv2.weight", "layer3.7.bn2.weight", "layer3.7.bn2.bias", "layer3.7.bn2.running_mean", "layer3.7.bn2.running_var", "layer3.7.bn2.num_batches_tracked",
    "layer3.8.conv1.weight", "layer3.8.bn1.weight", "layer3.8.bn1.bias", "layer3.8.bn1.running_mean", "layer3.8.bn1.running_var", "layer3.8.bn1.num_batches_tracked", "layer3.8.conv2.weight", "layer3.8.bn2.weight", "layer3.8.bn2.bias", "layer3.8.bn2.running_mean", "layer3.8.bn2.running_var", "layer3.8.bn2.num_batches_tracked".
I also tried to make sure that the keys of the two state_dicts are the same, using the following code:
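# `state` is the loaded checkpoint state_dict; `model` is the new network.
model_keys = model.state_dict().keys()
print('print keys from state_dict while not in model---------------------------------------')
# Print every checkpoint key that the model's own state_dict does not contain.
for item in state.keys():
    if item not in model_keys:
        print(item)
print(f'Length of keys in state_dict: {len(state.keys())}')
print(f'Length of keys in model: {len(model_keys)}')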
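And I got this result (no key is printed, and both lengths match):
print keys from state_dict while not in model---------------------------------------
Length of keys in state_dict: 344
Length of keys in model: 344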
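For reference, the same check can also be written as a set comparison that reports both directions at once. This is only a minimal sketch: `diff_keys` is a hypothetical helper name, and the short key lists are made-up stand-ins for `state.keys()` and `model.state_dict().keys()`, not my real 344 keys.

```python
def diff_keys(checkpoint_keys, model_keys):
    """Return (only_in_checkpoint, only_in_model) as sorted lists."""
    checkpoint_set = set(checkpoint_keys)
    model_set = set(model_keys)
    return sorted(checkpoint_set - model_set), sorted(model_set - checkpoint_set)


# Made-up stand-ins for state.keys() and model.state_dict().keys().
checkpoint_keys = ["conv1.weight", "layer2.5.conv1.weight", "fc.bias"]
model_keys = ["conv1.weight", "layer2.5.conv1.weight", "fc.bias"]

only_in_checkpoint, only_in_model = diff_keys(checkpoint_keys, model_keys)
print(f"Only in checkpoint: {only_in_checkpoint}")
print(f"Only in model: {only_in_model}")
```

With the real objects the call would be `diff_keys(state.keys(), model.state_dict().keys())`. Given the output above, I would expect both lists to come back empty, which is why the "unexpected key" error confuses me.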