我正在尝试将一个神经 ODE 拟合到使用 Julia 的 DiffEqFlux 的时间序列。这是我的代码:
u0 = Float32[2.;0]
train_size = 15
tspan_train = (0.0f0,0.75f0)
function trueODEfunc(du,u,p,t)
true_A = [-0.1 2.0; -2.0 -0.1]
du .= ((u.^3)'true_A)'
end
t_train = range(tspan_train[1],tspan_train[2],length = train_size)
prob = ODEProblem(trueODEfunc, u0, tspan_train)
ode_data_train = Array(solve(prob, Tsit5(),saveat=t_train))
dudt = Chain(
Dense(2,50,tanh),
Dense(50,2))
ps = Flux.params(dudt)
n_ode = NeuralODE(dudt, tspan_train, Tsit5(), saveat = t_train, reltol=1e-7, abstol=1e-9)
**n_ode.p**
function predict_n_ode(p)
n_ode(u0,p)
end
function loss_n_ode(p)
pred = predict_n_ode(p)
loss = sum(abs2, ode_data_train .- pred)
loss,pred
end
final_p = []
losses = []
cb = function(p,l,pred)
display(l)
display(p)
push!(final_p, p)
push!(losses,l)
pl = scatter(t_train, ode_data_train[1,:],label="data")
scatter!(pl,t_train,pred[1,:],label="prediction")
display(plot(pl))
end
DiffEqFlux.sciml_train!(loss_n_ode, n_ode.p, ADAM(0.05), cb = cb, maxiters = 100)
**n_ode.p**
问题是在 train 函数之前和之后调用n_ode.p
(或)给了我保存值。Flux.params(dudt)
我本来希望从培训中收到最新的更新值。这就是为什么我创建了一个数组来收集训练期间的所有参数值,然后访问它以获取更新的参数。
我在代码中做错了吗?train 函数会自动更新参数吗?如果没有怎么执行?
提前致谢!