2

我目前正在尝试在 Flux for Julia 中实现批量更新。

在我的计算过程中,我通过反复做得到一批标量

δ = Gt - model(St)[1]
push!(deltas,δ)

其中模型是神经网络

global model= Chain(
    Dense(statesize,10, leakyrelu),
    Dense(10,10,leakyrelu),
    Dense(10,1))

我最终得到了数组 deltas,我想在第二个神经网络上执行批量梯度更新(批量大小 = 19),其中每个梯度都被适当的 delta 加权。我写的更新函数是

function vupdate2!(S_batch,model,α,deltas)

   function v_loss_total(x)
       return sum(reshape(deltas,(1,19)) .* model(x))
   end

   local ps = Flux.params(model)
   local gs = Flux.Tracker.gradient(() -> v_loss_total(S_batch), ps)
   for p in ps
       Flux.Tracker.update!( p,  α.* gs[p])
   end
end

问题是,计算梯度的线会引发错误:MethodError: no method matching Float32(::Tracker.TrackedReal{Float64})

我认为问题是,我的增量数组被跟踪。查看随机输入的 v_loss_total 函数的输出,我得到:

julia> v_loss_total(S_batch)
-6752.433690476287 (tracked) (tracked)

有趣的是,这个数字被跟踪了两次(?),我猜这来自将两个跟踪的数字相乘(即增量和模型(S_batch)的条目)。有没有办法首先取消跟踪 delta 数组?我将不胜感激任何帮助。

4

2 回答 2

2

好的,事实证明,有一个函数

Flux.Tracker.data()

这正是我需要的。它接受一个跟踪的数字并返回 Float 本身。另见:https ://github.com/FluxML/Flux.jl/issues/640

于 2019-06-10T20:07:07.393 回答
1

在 julia 1.2 中对我有用的是通过使用将浮点数作为字段访问.data

GreenLogic 建议的上述函数仅返回另一个 Tracker。

于 2019-11-18T14:40:52.907 回答