我正在尝试使用 Lightning 在 PyTorch 中使用Ternausnet在 Carvana 数据集上重现 unet 结果。
我正在使用 DiceLoss 和 sigmoid 激活函数。我想我遇到了梯度消失的问题,因为所有的权重梯度都是 0,我看到网络的输出最小值为 10^8。
这里可能是什么问题?如何解决消失的梯度?此外,如果我使用不同的标准,我会看到损失在不停止的情况下变为负值的问题(例如,对于带有 logits 的 BCE)。
这是我的骰子损失的代码:
class DiceLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits, targets, eps=0, threshold=None):
# comment out if your model contains a sigmoid or
# equivalent activation layer
proba = torch.sigmoid(logits)
proba = proba.view(proba.shape[0], 1, -1)
targets = targets.view(targets.shape[0], 1, -1)
if threshold:
proba = (proba > threshold).float()
# flatten label and prediction tensors
intersection = torch.sum(proba * targets, dim=1)
summation = torch.sum(proba, dim=1) + torch.sum(targets, dim=1)
dice = (2.0 * intersection + eps) / (summation + eps)
# print(intersection, summation, dice)
return (1 - dice).mean()