我正在做一个图像分割任务。总共有 7 个类,所以最终的输出是一个类似于 [batch, 7, height, width] 的张量,它是一个 softmax 输出。现在直观地我想使用 CrossEntropy 损失,但 pytorch 实现不适用于通道明智的 one-hot 编码向量
所以我打算自己做一个函数。在一些stackoverflow的帮助下,我的代码到目前为止看起来像这样
from torch.autograd import Variable
import torch
import torch.nn.functional as F
def cross_entropy2d(input, target, weight=None, size_average=True):
# input: (n, c, w, z), target: (n, w, z)
n, c, w, z = input.size()
# log_p: (n, c, w, z)
log_p = F.log_softmax(input, dim=1)
# log_p: (n*w*z, c)
log_p = log_p.permute(0, 3, 2, 1).contiguous().view(-1, c) # make class dimension last dimension
log_p = log_p[
target.view(n, w, z, 1).repeat(0, 0, 0, c) >= 0] # this looks wrong -> Should rather be a one-hot vector
log_p = log_p.view(-1, c)
# target: (n*w*z,)
mask = target >= 0
target = target[mask]
loss = F.nll_loss(log_p, target.view(-1), weight=weight, size_average=False)
if size_average:
loss /= mask.data.sum()
return loss
images = Variable(torch.randn(5, 3, 4, 4))
labels = Variable(torch.LongTensor(5, 3, 4, 4).random_(3))
cross_entropy2d(images, labels)
我得到两个错误。代码本身提到了一个,它需要一个热向量。第二个说如下
RuntimeError: invalid argument 2: size '[5 x 4 x 4 x 1]' is invalid for input with 3840 elements at ..\src\TH\THStorage.c:41
例如,我试图让它解决一个 3 类问题。所以目标和标签是(不包括简化的批处理参数!)
目标:
Channel 1 Channel 2 Channel 3
[[0 1 1 0 ] [0 0 0 1 ] [1 0 0 0 ]
[0 0 1 1 ] [0 0 0 0 ] [1 1 0 0 ]
[0 0 0 1 ] [0 0 0 0 ] [1 1 1 0 ]
[0 0 0 0 ] [0 0 0 1 ] [1 1 1 0 ]
标签:
Channel 1 Channel 2 Channel 3
[[0 1 1 0 ] [0 0 0 1 ] [1 0 0 0 ]
[0 0 1 1 ] [.2 0 0 0] [.8 1 0 0 ]
[0 0 0 1 ] [0 0 0 0 ] [1 1 1 0 ]
[0 0 0 0 ] [0 0 0 1 ] [1 1 1 0 ]
那么如何修复我的代码来计算通道明智的 CrossEntropy 损失?