10

我刚刚在 keras 中实现了广义骰子损失(骰子损失的多类版本),如ref中所述:

(我的目标定义为:(batch_size,image_dim1,image_dim2,image_dim3,nb_of_classes))

def generalized_dice_loss_w(y_true, y_pred): 
    # Compute weights: "the contribution of each label is corrected by the inverse of its volume"
    Ncl = y_pred.shape[-1]
    w = np.zeros((Ncl,))
    for l in range(0,Ncl): w[l] = np.sum( np.asarray(y_true[:,:,:,:,l]==1,np.int8) )
    w = 1/(w**2+0.00001)

    # Compute gen dice coef:
    numerator = y_true*y_pred
    numerator = w*K.sum(numerator,(0,1,2,3))
    numerator = K.sum(numerator)

    denominator = y_true+y_pred
    denominator = w*K.sum(denominator,(0,1,2,3))
    denominator = K.sum(denominator)

    gen_dice_coef = numerator/denominator

    return 1-2*gen_dice_coef

但一定有什么不对劲。我正在处理必须为 4 个类(1 个背景类和 3 个对象类,我有一个不平衡的数据集)分割的 3D 图像。第一件奇怪的事情:虽然我的训练损失和准确性在训练期间有所提高(并且收敛速度非常快),但我的验证损失/准确性是恒定的低谷时期(见图。其次,在对测试数据进行预测时,只预测背景类:我得到一个恒定的体积。

我使用了完全相同的数据和脚本,但使用了分类交叉熵损失并获得了合理的结果(对象类被分段)。这意味着我的实现有问题。知道它可能是什么吗?

另外,我相信 keras 社区有一个通用的 dice loss 实现会很有用,因为它似乎被用于大多数最近的语义分割任务(至少在医学图像社区)。

PS:对我来说权重是如何定义的似乎很奇怪;我得到大约 10^-10 的值。还有其他人尝试过实现这一点吗?我还测试了没有权重的函数,但遇到了同样的问题。

4

1 回答 1

6

我认为这里的问题是你的体重。想象一下,您正在尝试解决多类分割问题,但在每张图像中只有少数类存在。一个玩具示例(也是导致我遇到此问题的示例)是通过以下方式从 mnist 创建分段数据集。

x = 28x28 图像和 y = 28x28x11 其中如果每个像素低于归一化灰度值 0.4,则将其分类为背景,否则将其分类为 x 的原始类别的数字。所以如果你看到一张第一的图片,你会看到一堆像素归为一个,还有背景。

现在在这个数据集中,图像中只会出现两个类。这意味着,在你的骰子损失之后,9 个权重将是 1./(0. + eps) = large ,因此对于每张图像,我们都会强烈惩罚所有 9 个不存在的类。在这种情况下,网络想要找到的一个明显强的局部最小值是将所有内容预测为背景类。

我们确实想惩罚任何不在图像中但不那么强烈的错误预测类别。所以我们只需要修改权重。我是这样做的:

def gen_dice(y_true, y_pred, eps=1e-6):
    """both tensors are [b, h, w, classes] and y_pred is in logit form"""

    # [b, h, w, classes]
    pred_tensor = tf.nn.softmax(y_pred)
    y_true_shape = tf.shape(y_true)

    # [b, h*w, classes]
    y_true = tf.reshape(y_true, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])
    y_pred = tf.reshape(pred_tensor, [-1, y_true_shape[1]*y_true_shape[2], y_true_shape[3]])

    # [b, classes]
    # count how many of each class are present in 
    # each image, if there are zero, then assign
    # them a fixed weight of eps
    counts = tf.reduce_sum(y_true, axis=1)
    weights = 1. / (counts ** 2)
    weights = tf.where(tf.math.is_finite(weights), weights, eps)

    multed = tf.reduce_sum(y_true * y_pred, axis=1)
    summed = tf.reduce_sum(y_true + y_pred, axis=1)

    # [b]
    numerators = tf.reduce_sum(weights*multed, axis=-1)
    denom = tf.reduce_sum(weights*summed, axis=-1)
    dices = 1. - 2. * numerators / denom
    dices = tf.where(tf.math.is_finite(dices), dices, tf.zeros_like(dices))
    return tf.reduce_mean(dices)
于 2019-11-17T14:44:56.190 回答