0

我正在尝试设计一个可以对细胞图像进行逐像素分割的 CNN。比如这些: 在此处输入图像描述

使用这样的分割掩码(除了每个原始图像的多个分割掩码,例如:单元格内部、单元格边界、背景):

在此处输入图像描述

我主要从这里复制了 U-net 设计:https ://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/

然而,即使是 10 个带注释的图像(超过 300 个单元格),我仍然得到相当糟糕的骰子系数分数,而且预测也不是很好。根据 U-Net 论文,这个数量的注释单元应该足以进行良好的预测。

这是我正在使用的模型的代码。

def get_unet():
inputs = Input((img_rows, img_cols, 1))
conv1 = Conv2D(16, window_size, activation='relu', padding='same')(inputs)
conv1 = Conv2D(16, window_size, activation='relu', padding='same')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

conv2 = Conv2D(64, window_size, activation='relu', padding='same')(pool1)
conv2 = Conv2D(64, window_size, activation='relu', padding='same')(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

conv3 = Conv2D(128, window_size, activation='relu', padding='same')(pool2)
conv3 = Conv2D(128, window_size, activation='relu', padding='same')(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

conv4 = Conv2D(128, window_size, activation='relu', padding='same')(pool3)
conv4 = Conv2D(128, window_size, activation='relu', padding='same')(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

conv5 = Conv2D(512, window_size, activation='relu', padding='same')(pool4)
conv5 = Conv2D(512, window_size, activation='relu', padding='same')(conv5)   


up6 = concatenate([Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
conv6 = Conv2D(128, window_size, activation='relu', padding='same')(up6)
conv6 = Conv2D(128, window_size, activation='relu', padding='same')(conv6)

up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
conv7 = Conv2D(128, window_size, activation='relu', padding='same')(up7)
conv7 = Conv2D(128, window_size, activation='relu', padding='same')(conv7)

up8 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
conv8 = Conv2D(64, window_size, activation='relu', padding='same')(up8)
conv8 = Conv2D(64, window_size, activation='relu', padding='same')(conv8)

up9 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
conv9 = Conv2D(16, window_size, activation='relu', padding='same')(up9)
conv9 = Conv2D(16, window_size, activation='relu', padding='same')(conv9)

conv10 = Conv2D(f_num, (1, 1), activation='softmax')(conv9) # change to N,(1,1) for more classes and softmax

model = Model(inputs=[inputs], outputs=[conv10])

model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])

return model`

我已经为模型尝试了许多不同的超参数,但都没有成功。骰子分数徘徊在 0.25 左右,我的损失在不同时期之间几乎没有减少。我觉得我在这里做一些根本错误的事情。有什么建议么?

编辑: Sigmoid 激活将骰子分数从 0.25 提高到 0.33(然而 1 个 epoch 再次达到这个分数,随后的 epoch 仅从 0.33 略微提高到 0.331 等)

dice_coef_loss 定义如下

smooth = 1.

def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return -dice_coef(y_true, y_pred)

另外,如果 model.summary 输出有用的话:


Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 64, 64, 1)         0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 64, 64, 16)        32        
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 64, 64, 16)        272       
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 32, 32, 16)        0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 32, 32, 64)        1088      
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 32, 32, 64)        4160      
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 16, 16, 128)       8320      
_________________________________________________________________
conv2d_25 (Conv2D)           (None, 16, 16, 128)       16512     
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_26 (Conv2D)           (None, 8, 8, 128)         16512     
_________________________________________________________________
conv2d_27 (Conv2D)           (None, 8, 8, 128)         16512     
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 4, 4, 128)         0         
_________________________________________________________________
conv2d_28 (Conv2D)           (None, 4, 4, 512)         66048     
_________________________________________________________________
conv2d_29 (Conv2D)           (None, 4, 4, 512)         262656    
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 8, 8, 512)         1049088   
_________________________________________________________________
concatenate_5 (Concatenate)  (None, 8, 8, 640)         0         
_________________________________________________________________
conv2d_30 (Conv2D)           (None, 8, 8, 128)         82048     
_________________________________________________________________
conv2d_31 (Conv2D)           (None, 8, 8, 128)         16512     
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 16, 16, 128)       65664     
_________________________________________________________________
concatenate_6 (Concatenate)  (None, 16, 16, 256)       0         
_________________________________________________________________
conv2d_32 (Conv2D)           (None, 16, 16, 128)       32896     
_________________________________________________________________
conv2d_33 (Conv2D)           (None, 16, 16, 128)       16512     
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 32, 32, 128)       65664     
_________________________________________________________________
concatenate_7 (Concatenate)  (None, 32, 32, 192)       0         
_________________________________________________________________
conv2d_34 (Conv2D)           (None, 32, 32, 64)        12352     
_________________________________________________________________
conv2d_35 (Conv2D)           (None, 32, 32, 64)        4160      
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 64, 64, 64)        16448     
_________________________________________________________________
concatenate_8 (Concatenate)  (None, 64, 64, 80)        0         
_________________________________________________________________
conv2d_36 (Conv2D)           (None, 64, 64, 16)        1296      
_________________________________________________________________
conv2d_37 (Conv2D)           (None, 64, 64, 16)        272       
_________________________________________________________________
conv2d_38 (Conv2D)           (None, 64, 64, 4)         68        
=================================================================
Total params: 1,755,092.0
Trainable params: 1,755,092.0
Non-trainable params: 0.0
4

0 回答 0