7

我正在尝试在损失函数中使用分位数进行训练!(对于某些稳健性,例如最小修剪平方),但它会改变数组并且 Zygote 会抛出错误Mutating arrays is not supported,来自sort!. 下面是一个简单的例子(内容当然没有意义):

using Flux, StatsBase
xdata = randn(2, 100)   
ydata = randn(100)

model = Chain(Dense(2,10), Dense(10, 1))


function trimmedLoss(x,y; trimFrac=0.f05)
        yhat = model(x)
        absRes = abs.(yhat .- y) |> vec
        trimVal = quantile(absRes, 1.f0-trimFrac) 
        s = sum(ifelse.(absRes .> trimVal,  0.f0 , absRes ))/(length(absRes)*(1.f0-trimFrac))
        #s = sum(absRes)/length(absRes)   # using this and commenting out the two above works (no surprise)    
end

println(trimmedLoss(xdata, ydata)) #works ok

Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())

println(trimmedLoss(xdata, ydata)) #changed loss?

这一切都在 Flux 0.10 和 Julia 1.2 中

提前感谢任何提示或解决方法!

4

1 回答 1

7

理想情况下,我们会定义一个自定义伴随程序,quantile以便开箱即用。(随时打开一个问题来提醒我们这样做。)

与此同时,有一个快速的解决方法。实际上是排序在这里造成了麻烦,所以如果你这样做quantile(xs, p, sorted=true),它就会起作用。显然,这需要xs进行排序以获得正确的结果,因此您可能需要使用quantile(sort(xs), ...).

根据您的 Zygote 版本,您可能还需要sort. 那很简单:

julia> using Zygote: @adjoint

julia> @adjoint function sort(x)
         p = sortperm(x)
         x[p], x̄ -> (x̄[invperm(p)],)
       end

julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)

我们将在下一个 Zygote 版本中内置它,但现在如果你将它添加到你的脚本中,它会让你的代码工作。

于 2020-01-20T11:38:09.723 回答