2

在 Torch 中实现自定义损失函数的必要步骤是什么?

看来您必须为 updateOutput 和 updateGradInput 编写一个实现。

这就是全部?那么你基本上创建了一个新类:

local CustomCriterion, parent =   torch.class('CustomCriterion','nn.Criterion')

并实现以下两个功能:

function CustomCriterion:updateOutput(input, target)
function CustomCriterion:updateGradInput(input, target)

这是正确的,还是还有更多工作要做?

此外,对于提供的标准,这些函数是用 C 实现的,但我想 Lua 实现也可以工作,尽管可能会慢一点?

4

1 回答 1

0

我已经实现了表单的功能(在伪代码中)

--assuming input is partitioned in input_a,input_b
--         target is accordingly partitionend in target_a, target_b  
f(input)=MSE(input_a,target_a)+ custom_sutff(input_b,target_b)

很多次就像你描述的那样。所以,据我所知,我认为你的两个问题的答案都是肯定的。

基本上nn / MSECriterion.lua似乎支持这一点。

于 2017-06-26T22:49:02.523 回答