I am trying to optimise a fully convolutional network on the GPU with Flux, using the `logitbinarycrossentropy` loss function:
Flux.train!(loss, params(UNet_model), train_batch, optimiser)
where the loss is defined as:
loss(x, y) = mean(logitbinarycrossentropy.(model(x) |> gpu, y))
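However, I get an `ERROR: LoadError: scalar getindex is disallowed` error. I have verified that my model runs correctly (it returns the expected output, etc.) and that the loss function computes a value. I tried rewriting `logitbinarycrossentropy` as a custom loss function, but it still fails with the same error at `update!(opt, ps, gs)`. I also tried using `σ` with `binarycrossentropy` instead of `logitbinarycrossentropy`, and got the same error.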
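For reference, the two loss formulations mentioned above should be mathematically equivalent, which is consistent with both of them failing in the same way. The following is a minimal CPU-only sketch in plain Julia, not Flux's actual implementation: the helper names `sigm`, `logsigm`, `bce_sketch` and `logit_bce_sketch`, as well as the sample values, are hypothetical stand-ins chosen for illustration.

```julia
using Statistics  # provides mean

# Hypothetical stand-ins for the sigmoid and a numerically stable log-sigmoid.
sigm(x) = 1 / (1 + exp(-x))
logsigm(x) = min(x, zero(x)) - log1p(exp(-abs(x)))

# Scalar losses, applied elementwise via broadcasting as in the loss above.
bce_sketch(p, y) = -y * log(p) - (1 - y) * log(1 - p)   # expects probabilities
logit_bce_sketch(z, y) = (1 - y) * z - logsigm(z)       # expects raw logits

z = Float32[-2.0, 0.0, 3.0]   # assumed raw model outputs (logits)
y = Float32[0.0, 1.0, 1.0]    # assumed binary targets

loss_a = mean(logit_bce_sketch.(z, y))
loss_b = mean(bce_sketch.(sigm.(z), y))

# Both formulations agree up to floating-point error.
println(isapprox(loss_a, loss_b; atol=1e-5))
```

The logit form avoids computing `log(1 - p)` for probabilities close to 1, which is the usual reason to prefer it.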
I am using Julia 1.3.0, Flux 0.10.3, and Zygote 0.4.6.
I suspect this may be related to my use of `cat` in the model definition. Is that possible? This is roughly how I define the model:
function (t::test)(x)
    enc1 = t.conv_block1[1](x)
    bn = t.bottle(enc1)
    dec1 = t.upconv_block[1](bn)
    dec1 = cat(dims=3, dec1, enc1)
    dec1 = t.conv_block[2](dec1)
    dec1 = t.conv(dec1)
end
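To make the role of `cat` in the forward pass above concrete, here is a small CPU-only sketch of what the skip connection does to array shapes. The array sizes are assumptions picked for illustration, not the sizes of the real network; Flux's convolution layers use a width × height × channels × batch layout, so `dims=3` is the channel dimension.

```julia
# Hypothetical feature maps in width × height × channels × batch layout.
enc1 = rand(Float32, 8, 8, 16, 2)   # assumed encoder output: 16 channels
dec1 = rand(Float32, 8, 8, 16, 2)   # assumed upsampled decoder output: 16 channels

# Concatenating along dims=3 stacks the channels; all other dimensions must match.
merged = cat(dims=3, dec1, enc1)

println(size(merged))  # (8, 8, 32, 2)
```

The layer applied after the concatenation therefore has to accept the summed channel count (32 in this sketch).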
Full stack trace:
ERROR: LoadError: scalar getindex is disallowed
Stacktrace:
[1] error(::String) at .\error.jl:33
[2] assertscalar(::String) at C:\Users\CCL\.julia\packages\GPUArrays\1wgPO\src\indexing.jl:14
[3] getindex at C:\Users\CCL\.julia\packages\GPUArrays\1wgPO\src\indexing.jl:54 [inlined]
[4] _getindex at .\abstractarray.jl:1004 [inlined]
[5] getindex at .\abstractarray.jl:981 [inlined]
[6] hash(::CuArray{Float32,4,Nothing}, ::UInt64) at .\abstractarray.jl:2203
[7] hash at .\hashing.jl:18 [inlined]
[8] hashindex at .\dict.jl:168 [inlined]
[9] ht_keyindex(::Dict{Any,Any}, ::CuArray{Float32,4,Nothing}) at .\dict.jl:282
[10] get(::Dict{Any,Any}, ::CuArray{Float32,4,Nothing}, ::Nothing) at .\dict.jl:500
[11] (::Zygote.var"#876#877"{Zygote.Context,IdDict{Any,Any},CuArray{Float32,4,Nothing}})(::Nothing) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\lib\base.jl:44
[12] (::Zygote.var"#2375#back#878"{Zygote.var"#876#877"{Zygote.Context,IdDict{Any,Any},CuArray{Float32,4,Nothing}}})(::Nothing) at C:\Users\CCL\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
[13] #fmap#53 at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\functor.jl:37 [inlined]
[14] (::typeof(∂(#fmap#53)))(::CuArray{Float32,4,Nothing}) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
[15] fmap at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\functor.jl:36 [inlined]
[16] (::typeof(∂(fmap)))(::CuArray{Float32,4,Nothing}) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
[17] gpu at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\functor.jl:108 [inlined]
[18] (::typeof(∂(gpu)))(::CuArray{Float32,4,Nothing}) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
[19] |> at .\operators.jl:854 [inlined]
[20] (::typeof(∂(|>)))(::CuArray{Float32,4,Nothing}) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
[21] (::typeof(∂(loss)))(::Float32) at C:\Users\CCL\fcn\main_flux2.jl:47
[22] #157 at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\lib\lib.jl:156 [inlined]
[23] #297#back at C:\Users\CCL\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49 [inlined]
[24] #17 at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\optimise\train.jl:88 [inlined]
[25] (::typeof(∂(λ)))(::Float32) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
[26] (::Zygote.var"#38#39"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface.jl:101
[27] gradient(::Function, ::Zygote.Params) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface.jl:47
[28] macro expansion at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\optimise\train.jl:87 [inlined]
[29] macro expansion at C:\Users\CCL\.julia\packages\Juno\f8hj2\src\progress.jl:134 [inlined]
[30] #train!#12(::Flux.Optimise.var"#18#26", ::typeof(Flux.Optimise.train!), ::typeof(loss), ::Zygote.Params, ::Array{Tuple{CuArray{Float32,4,Nothing},CuArray{Float32,4,Nothing}},1}, ::ADAM) at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\optimise\train.jl:80
[31] train!(::Function, ::Zygote.Params, ::Array{Tuple{CuArray{Float32,4,Nothing},CuArray{Float32,4,Nothing}},1}, ::ADAM) at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\optimise\train.jl:78
[32] top-level scope at C:\Users\CCL\fcn\main_flux2.jl:83
[33] include at .\boot.jl:328 [inlined]
[34] include_relative(::Module, ::String) at .\loading.jl:1105
[35] include(::Module, ::String) at .\Base.jl:31
[36] exec_options(::Base.JLOptions) at .\client.jl:287
[37] _start() at .\client.jl:460
in expression starting at C:\Users\CCL\fcn\main_flux2.jl:74
Thanks!