
I have defined a loss function in PyTorch, but I get an error I can't find a solution for. Here is my code:

```python
import torch
import torch.optim as optim


class cust_loss(torch.nn.Module):
    def __init__(self):
        super(cust_loss, self).__init__()

    def forward(self, input, target):
        # torch.max(input, 1) returns (values, indices); [1] selects the predicted class indices
        predicted_labels = torch.max(input, 1)[1]
        minus = predicted_labels - target
        cust_distance = torch.sum(minus * minus).type(torch.FloatTensor) / predicted_labels.size()[0]
        return cust_distance


######## within main function ######

criterion = cust_loss()  # nn.CrossEntropyLoss()
Optimizer = optim.SGD(filter(lambda p: p.requires_grad, model_conv.parameters()), lr=1e-3, momentum=0.9)

loss = criterion(inputs, labels)
loss.backward()
```

Unfortunately, I get this error:

```
Traceback (most recent call last):
  File "/home/morteza/PycharmProjects/transfer_learning/test_SkinDetection.py", line 250, in <module>
    main(True)
  File "/home/morteza/PycharmProjects/transfer_learning/test_SkinDetection.py", line 130, in main
    loss.backward()
  File "/home/morteza/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py", line 156, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/home/morteza/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 98, in backward
    variables, grad_variables, retain_graph)
  File "/home/morteza/anaconda3/lib/python3.6/site-packages/torch/autograd/function.py", line 91, in apply
    return self._forward_cls.backward(self, *args)
  File "/home/morteza/anaconda3/lib/python3.6/site-packages/torch/autograd/_functions/basic_ops.py", line 38, in backward
    return maybe_unexpand(grad_output, ctx.a_size), maybe_unexpand_or_view(grad_output.neg(), ctx.b_size), None
  File "/home/morteza/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py", line 381, in neg
    return Negate.apply(self)
  File "/home/morteza/anaconda3/lib/python3.6/site-packages/torch/autograd/_functions/basic_ops.py", line 224, in forward
    return i.neg()
AttributeError: 'torch.LongTensor' object has no attribute 'neg'
```
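
The `LongTensor` in this traceback points at the second value returned by `torch.max(input, 1)`: it is an integer tensor of argmax indices, and the subtraction in `forward` is carried out on those integers. The sketch below is a minimal, self-contained check that assumes a recent PyTorch release (where `Variable` has been merged into `Tensor`, so the failure surfaces as a `RuntimeError` rather than the old `AttributeError`); `logits` and the small batch are hypothetical stand-ins for `inputs` and `labels`. It shows that the indices do not track gradients even when the input does, so the resulting loss is disconnected from the model.

```python
import torch

# Hypothetical mini-batch: 4 samples, 3 classes. `logits` plays the role of `input`.
torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)
target = torch.tensor([0, 2, 1, 2])

# torch.max(x, 1) returns (max values, argmax indices) along dim 1.
values, predicted_labels = torch.max(logits, 1)
print(values.requires_grad)            # True: the max values stay in the graph
print(predicted_labels.dtype)          # torch.int64: class indices are integers
print(predicted_labels.requires_grad)  # False: indices carry no gradient

# Same arithmetic as cust_loss.forward.
minus = predicted_labels - target
cust_distance = torch.sum(minus * minus).type(torch.FloatTensor) / predicted_labels.size()[0]
print(cust_distance.requires_grad)     # False: the loss is cut off from `logits`

backward_failed = False
try:
    cust_distance.backward()
except RuntimeError as err:
    # Recent PyTorch refuses to backpropagate through a tensor with no grad_fn.
    backward_failed = True
    print("backward() failed:", err)
```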

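I can't solve it. I traced through the code and compared it with code that runs without errors, but I still couldn't fix it. Also, I defined the inputs and labels as Variables with the `requires_grad=True` argument. Please advise me on how to fix this. Thank you.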
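
Because an argmax index has no gradient, a loss built only from `torch.max(input, 1)[1]` cannot train the model. If the goal is to penalize how far the predicted class index lies from the target, one possible differentiable substitute (a swapped-in technique, not the loss from the question) replaces the hard argmax with softmax probabilities and computes the expected squared distance between the class index and the target. This is a sketch under stated assumptions: a recent PyTorch release, `target` holding integer class indices, and a hypothetical class name `ExpectedSquaredDistanceLoss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExpectedSquaredDistanceLoss(nn.Module):
    """Hypothetical differentiable replacement for cust_loss.

    Uses softmax probabilities instead of the hard argmax index and returns the
    batch mean of sum_k p_k * (k - y)^2, i.e. the expected squared distance
    between the predicted class index k and the target index y.
    """

    def forward(self, input, target):
        probs = F.softmax(input, dim=1)  # (N, C), differentiable w.r.t. input
        classes = torch.arange(input.size(1), dtype=probs.dtype, device=input.device)  # (C,)
        # (1, C) - (N, 1) broadcasts to (N, C): distance of every class from each target
        diff = classes.unsqueeze(0) - target.to(probs.dtype).unsqueeze(1)
        return (probs * diff * diff).sum(dim=1).mean()


# Usage on a hypothetical batch of 4 samples and 3 classes.
torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)
target = torch.tensor([0, 2, 1, 2])

criterion = ExpectedSquaredDistanceLoss()
loss = criterion(logits, target)
loss.backward()  # works: gradients flow back through the softmax
print(loss.item(), logits.grad.shape)
```

This is a smooth surrogate rather than the same quantity as the argmax distance: it only reaches zero when the probabilities are fully concentrated on the target class. For plain classification, `nn.CrossEntropyLoss` (commented out in the question's code) remains the usual choice.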
