我正在尝试训练 LSTM 以基于 x 序列(不仅仅是最后一项或分类器)对完整序列 y 进行建模。使用以下代码,尽管损失函数有效,但训练不起作用。似乎点形式主义不适用于火车!? 任何想法我该怎么做?在 Keras 中,它是如此简单......在此先感谢 Markus
using Flux
# Create synthetic data first
### Function to generate x consisting of three variables and a sequence length of 200
function generateX()
x1 = Array{Float32, 1}(randn(200))
x2 = Array{Float32, 1}(randn(200))
x3 = Array{Float32, 1}(sin.((0:199) / 12*2*pi))
xdata=[x1 x2 x3]'
return(xdata)
end
### Generate 50 of these sequences of x
xdata = [generateX() for i in 1:50]
### Function to generate sequence of y from x sequence
function yfromx(x)
y=Array{Float32, 1}(0.2*cumsum(x[1,:].*x[2,:].*exp.(x[1,:])) .+x[3,:])
return(y')
end
ydata = map(yfromx, xdata);
### Now rearrange such that there is a sequence of 200 X inputs, i.e. an array of x vectors (and 50 of those sequences)
xdata=Flux.batch(xdata)
xdata2 = [xdata[:,s,c] for s in 1:200, c in 1:50]
xdata= [xdata2[:,c] for c in 1:50]
### Same for y
ydata=Flux.batch(ydata)
ydata2 = [ydata[:,s,c] for s in 1:200, c in 1:50]
ydata= [ydata2[:,c] for c in 1:50]
### Define model and loss function. "model." returns sequence of y from sequence of x
import Base.Iterators: flatten
model=Chain(LSTM(3, 26), Dense(26,1))
loss(x,y) = Flux.mse(collect(flatten(model.(x))),collect(flatten(y)))
model.(xdata[1]) # works fine
loss(xdata[2],ydata[2]) # also works fine
Flux.train!(loss, params(model), zip(xdata, ydata), ADAM(0.005)) ## Does not work, see error below. How to work around?
错误信息
Mutating arrays is not supported
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::getfield(Zygote, Symbol("##992#993")))(::Nothing) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/lib/array.jl:44
[3] (::getfield(Zygote, Symbol("##2633#back#994")){getfield(Zygote, Symbol("##992#993"))})(::Nothing) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[4] copyto! at ./abstractarray.jl:725 [inlined]
[5] (::typeof(∂(copyto!)))(::Array{Float32,1}) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
[6] _collect at ./array.jl:550 [inlined]
[7] (::typeof(∂(_collect)))(::Array{Float32,1}) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
[8] collect at ./array.jl:544 [inlined]
[9] (::typeof(∂(collect)))(::Array{Float32,1}) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
[10] loss at ./In[20]:4 [inlined]
[11] (::typeof(∂(loss)))(::Float32) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
[12] #153 at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/lib/lib.jl:142 [inlined]
[13] #283#back at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[14] #15 at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Flux/oX9Pi/src/optimise/train.jl:69 [inlined]
[15] (::typeof(∂(λ)))(::Float32) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface2.jl:0
[16] (::getfield(Zygote, Symbol("##38#39")){Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface.jl:101
[17] gradient(::Function, ::Zygote.Params) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Zygote/fw4Oc/src/compiler/interface.jl:47
[18] macro expansion at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Flux/oX9Pi/src/optimise/train.jl:68 [inlined]
[19] macro expansion at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Juno/oLB1d/src/progress.jl:134 [inlined]
[20] #train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{Array{Array{Array{Float32,1},1},1},Array{LinearAlgebra.Adjoint{Float32,Array{Float32,1}},1}}}, ::ADAM) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Flux/oX9Pi/src/optimise/train.jl:66
[21] train!(::Function, ::Zygote.Params, ::Base.Iterators.Zip{Tuple{Array{Array{Array{Float32,1},1},1},Array{LinearAlgebra.Adjoint{Float32,Array{Float32,1}},1}}}, ::ADAM) at /Net/Groups/BGI/scratch/mreichstein/julia_atacama_depots/packages/Flux/oX9Pi/src/optimise/train.jl:64
[22] top-level scope at In[24]:1
loss(xdata[2],ydata[2])