1

我正在尝试使用 Lightning 在 PyTorch 中使用Ternausnet在 Carvana 数据集上重现 unet 结果。

我正在使用 DiceLoss 和 sigmoid 激活函数。我想我遇到了梯度消失的问题,因为所有的权重梯度都是 0,我看到网络的输出最小值为 10^8。

这里可能是什么问题?如何解决消失的梯度?此外,如果我使用不同的标准,我会看到损失在不停止的情况下变为负值的问题(例如,对于带有 logits 的 BCE)。

这是我的骰子损失的代码:

class DiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, logits, targets, eps=0, threshold=None):

        # comment out if your model contains a sigmoid or
        # equivalent activation layer
        proba = torch.sigmoid(logits)
        proba = proba.view(proba.shape[0], 1, -1)
        targets = targets.view(targets.shape[0], 1, -1)
        if threshold:
            proba = (proba > threshold).float()
        # flatten label and prediction tensors

        intersection = torch.sum(proba * targets, dim=1)
        summation = torch.sum(proba, dim=1) + torch.sum(targets, dim=1)
        dice = (2.0 * intersection + eps) / (summation + eps)
        # print(intersection, summation, dice)
        return (1 - dice).mean()
4

0 回答 0