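I have a classic example of the exploding gradient problem, and I would like to solve it with gradient clipping. What is the interface for doing this in Flux?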

1 Answer

Optimisers (including gradient clipping) can be used in different ways. The first is to apply them by hand, like this:

julia> using Flux

julia> W = rand(2, 5)
2×5 Matrix{Float64}:
 0.107144  0.643693  0.399019  0.764073  0.78122
 0.367751  0.335326  0.442312  0.433656  0.443901

julia> b = rand(2)
2-element Vector{Float64}:
 0.035723018492827885
 0.9063968296104223

julia> predict(x) = (W * x) .+ b
predict (generic function with 1 method)

julia> loss(x, y) = sum((predict(x) .- y).^2)
loss (generic function with 1 method)

julia> x, y = rand(5), rand(2) # Dummy data
([0.4878962006153771, 0.1293768496171035, 0.4662237969593086, 0.43195747100830384, 0.10672368947541733], [0.864923828559593, 0.6643701281693306])

julia> l = loss(x, y) # loss on the dummy data
0.8292492365517469

julia> θ = params(W, b)
Params([[0.10714416442012298 0.6436932411339433 … 0.7640730577168127 0.7812198182421601; 0.3677513353707582 0.3353255969566744 … 0.4336560750116858 0.44390077304165043], [0.035723018492827885, 0.9063968296104223]])

julia> grads = gradient(() -> loss(x, y), θ)
Grads(...)

julia> using Flux.Optimise

julia> opt = Optimiser(ClipValue(1e-3), ADAM(1e-3))
Optimiser(Any[ClipValue{Float64}(0.001), ADAM(0.001, (0.9, 0.999), IdDict{Any, Any}())])

julia> for p in (W, b)
         update!(opt, p, grads[p])
       end
# This is a somewhat poor example, since there are no exploding gradients here, but the mechanics would be the same if there were.
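To make it concrete what the clipping step does to a gradient before ADAM sees it, here is a minimal sketch in plain Julia (no Flux needed). The helpers `clip_value` and `clip_norm` are hypothetical names. They mirror my understanding of what Flux's `ClipValue` and `ClipNorm` do and are not Flux's own source. The gradient `g` is made up for illustration.

```julia
using LinearAlgebra  # provides norm

# Hypothetical helper: clamp every entry into [-thresh, thresh]
# (element-wise clipping, the idea behind ClipValue).
clip_value(g, thresh) = clamp.(g, -thresh, thresh)

# Hypothetical helper: rescale the whole array when its L2 norm
# exceeds thresh (norm clipping, the idea behind ClipNorm).
function clip_norm(g, thresh)
    n = norm(g)
    return n > thresh ? g .* (thresh / n) : g
end

g = [3.0, -4.0, 0.5]       # made-up "large" gradient

gv = clip_value(g, 1.0)    # large entries are cut to ±1.0, small ones are kept
gn = clip_norm(g, 1.0)     # direction is preserved, length is reduced to 1.0
```

Value clipping can change the direction of the gradient, while norm clipping only shrinks its length. Which one suits an exploding-gradient problem better depends on the model.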

Alternatively, you can pass the optimiser (opt = Optimiser(ClipValue(1e-3), ADAM(1e-3)), taken from here: https://fluxml.ai/Flux.jl/stable/training/optimisers/#Gradient-Clipping) into a training loop, like this:

for d in datapoints

  # `d` should produce a collection of arguments
  # to the loss function

  # Calculate the gradients of the parameters
  # with respect to the loss function
  grads = Flux.gradient(parameters) do
    loss(d...)
  end

  # Update the parameters based on the chosen
  # optimiser (opt)
  Flux.Optimise.update!(opt, parameters, grads)
end
# Example from here: https://fluxml.ai/Flux.jl/stable/training/training/#Training

where opt is defined as in the example shown above.
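To show the loop's mechanics without depending on a particular Flux version, here is a rough sketch in plain Julia. It is not Flux's implementation. It assumes the same linear model and squared-error loss as above, uses a hand-derived gradient, and swaps ADAM for plain gradient descent. The data and the learning rate `η` are made up for illustration.

```julia
# Plain-Julia sketch (no Flux): clip each gradient, then take a descent step.
W = [0.1 0.6 0.4 0.8 0.8; 0.4 0.3 0.4 0.4 0.4]
b = [0.0, 0.9]
x = [0.5, 0.1, 0.5, 0.4, 0.1]   # made-up input
y = [0.9, 0.7]                  # made-up target

loss(W, b) = sum(((W * x) .+ b .- y) .^ 2)

thresh = 1e-3   # same threshold as ClipValue(1e-3)
η = 0.1         # hypothetical learning rate

initial_loss = loss(W, b)
for step in 1:100
    r  = (W * x) .+ b .- y      # residual of the prediction
    gW = (2 .* r) * x'          # gradient of the loss with respect to W
    gb = 2 .* r                 # gradient of the loss with respect to b
    # Clip element-wise, then update in place.
    W .-= η .* clamp.(gW, -thresh, thresh)
    b .-= η .* clamp.(gb, -thresh, thresh)
end
final_loss = loss(W, b)
```

With a threshold this small, each parameter moves by at most `η * thresh` per step, so the loss falls slowly. In practice the threshold is tuned so that only unusually large gradients are clipped.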

Answered 2021-07-03T15:29:32.717