加载自定义 model.pt 以进行推理时出现错误。错误是 TypeError:load_state_dict() 缺少 1 个必需的位置参数:'state_dict'。
model = get_model(model_path, model_type='UNet',problem_type='parts')
这是 Unet 模型
从火炬导入火炬 从火炬视觉.models.vgg 导入nn 导入vgg16_bn
class ConvBNAct(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(in_channels, out_channels,
kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
def forward(self, inputs):
return self.seq(inputs)
class Block(nn.Module):
def __init__(self, src_channels, dst_channels):
super().__init__()
self.seq1 = ConvBNAct(src_channels, dst_channels)
self.seq2 = ConvBNAct(dst_channels, dst_channels)
self.seq3 = ConvBNAct(dst_channels, dst_channels)
def forward(self, x):
result = self.seq1(x)
result = self.seq2(result)
result = self.seq3(result)
return result
class UNetUp(nn.Module):
def __init__(self, down_channels, right_channels):
super().__init__()
self.bottom_up = nn.Upsample(scale_factor=2, mode='nearest')
self.conv = nn.Conv2d(down_channels, right_channels, kernel_size=1, stride=1)
def forward(self, left, bottom):
from_bottom = self.bottom_up(bottom)
from_bottom = self.conv(from_bottom)
result = torch.cat([left, from_bottom], 1)
return result
class Bottleneck(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), dilation=2, padding=2)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), dilation=2, padding=2)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu2 = nn.ReLU()
def forward(self, x):
out = self.conv(x)
out = self.bn(out)
out = self.conv2(self.relu(out))
out = self.bn2(out)
return torch.cat((x, self.relu2(out)), dim=1)
class UNet(nn.Module):
def __init__(self, encoder_blocks, encoder_channels, n_cls):
self.encoder_channels = encoder_channels
self.depth = len(self.encoder_channels)
assert len(encoder_blocks) == self.depth
super().__init__()
self.encoder_blocks = nn.ModuleList(encoder_blocks)
self.blocks = nn.ModuleList()
# add bottleneck
self.blocks.append(Block(
self.encoder_channels[-1],
self.encoder_channels[-1]
))
self.ups = nn.ModuleList()
for i in range(1, self.depth):
bottom_channels = self.encoder_channels[self.depth - i]
left_channels = self.encoder_channels[self.depth - i - 1]
right_channels = left_channels
self.ups.append(UNetUp(bottom_channels, right_channels))
self.blocks.append(Block(
left_channels + right_channels,
right_channels
))
self.last_conv = nn.Conv2d(encoder_channels[0], n_cls, 1)
# self.dropout = nn.Dropout2d(p=0.1)
self.bottle = Bottleneck(512, 512)
def forward(self, x):
encoder_outputs = []
for encoder_block in self.encoder_blocks:
x = encoder_block(x)
encoder_outputs.append(x)
x = self.bottle(encoder_outputs[self.depth - 1])
for i in range(self.depth):
if i > 0:
encoder_output = encoder_outputs[self.depth - i - 1]
x = self.ups[i - 1](encoder_output, x)
x = self.blocks[i](x)
# x = self.dropout(x)
x = self.last_conv(x)
return x # no softmax or log_softmax
def _get_encoder_blocks(model):
# last modules (ReLUs) of VGG blocks
layers_last_module_names = ['5', '12', '22', '32', '42']
result = []
cur_block = nn.Sequential()
for name, child in model.named_children():
if name == 'features':
for name2, child2 in child.named_children():
cur_block.add_module(name2, child2)
if name2 in layers_last_module_names:
result.append(cur_block)
cur_block = nn.Sequential()
break
return result
def construct_unet(n_cls): # no weights inited
model = vgg16_bn(pretrained=False)
encoder_blocks = _get_encoder_blocks(model)
encoder_channels = [64, 128, 256, 512, 1024] # vgg16 channels
# prev_channels = encoder_channels[-1]
return UNet(encoder_blocks, encoder_channels, n_cls)