在 Torch 中实现自定义损失函数的必要步骤是什么?
看来您必须为 updateOutput 和 updateGradInput 编写一个实现。
这就是全部?那么你基本上创建了一个新类:
local CustomCriterion, parent = torch.class('CustomCriterion','nn.Criterion')
并实现以下两个功能:
function CustomCriterion:updateOutput(input, target)
function CustomCriterion:updateGradInput(input, target)
这是正确的,还是还有更多工作要做?
此外,对于提供的标准,这些函数是用 C 实现的,但我想 Lua 实现也可以工作,尽管可能会慢一点?