I am trying to train a fully convolutional network on the GPU with Flux, using the logitbinarycrossentropy function as the loss:

Flux.train!(loss, params(UNet_model), train_batch, optimiser)

where the loss is defined as:

loss(x, y) = mean(logitbinarycrossentropy.(model(x) |> gpu, y))
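In the stack trace included in this question, frames [17] to [21] show the backward pass of loss (main_flux2.jl:47) going through |> and gpu, so Zygote is differentiating through the gpu call itself. Frame [30] also shows that train_batch already holds CuArray tuples. Assuming the model was also moved to the GPU once before training (for example with gpu(UNet_model)), model(x) already returns a CuArray and the |> gpu inside the loss is redundant. A hedged sketch of the loss without it, written against the Flux 0.10 API and not a confirmed fix:

```julia
using Flux
using Statistics: mean

# Sketch, assuming `model` (the network being trained) and every (x, y) batch
# were moved to the GPU *before* training. model(x) then already returns a
# CuArray, so there is no `|> gpu` here and Zygote never differentiates `gpu`.
loss(x, y) = mean(Flux.logitbinarycrossentropy.(model(x), y))
```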

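However, I get the error ERROR: LoadError: scalar getindex is disallowed. I have verified that my model runs correctly (it returns the expected output, etc.) and that the loss function evaluates to a value. I tried rewriting logitbinarycrossentropy as a custom loss function, but it still fails with the same error at update!(opt, ps, gs). I also tried using σ with binarycrossentropy instead of logitbinarycrossentropy and got the same error.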
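Flux.train! is essentially a loop that computes gradients with Zygote and then calls update!. In the stack trace included in this question, the error is raised inside gradient (frames [27] and [28], called from train.jl:87). A hypothetical manual loop (manual_train! is not a Flux function) that keeps the two steps apart, written against the implicit-parameter API of Flux 0.10, makes it easier to see which step fails:

```julia
using Flux

# Hypothetical helper (not part of Flux): an explicit version of
# Flux.train!(loss, ps, data, opt) using Flux's implicit-parameter API.
function manual_train!(loss, ps, data, opt)
    for (x, y) in data
        # Backward pass: Zygote differentiates loss w.r.t. every array in ps.
        gs = Flux.gradient(() -> loss(x, y), ps)
        # Optimiser step: update each parameter in place using its gradient.
        Flux.Optimise.update!(opt, ps, gs)
    end
    return nothing
end
```

With the names from the question it would be called as manual_train!(loss, params(UNet_model), train_batch, optimiser); wrapping the two steps in separate try blocks then pinpoints the failing one.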

I am using Julia 1.3.0, Flux 0.10.3, and Zygote 0.4.6.

I suspect this might be related to my use of cat in the model definition. Is that possible? This is roughly how I define the model:

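function (t::test)(x)
    enc1 = t.conv_block1[1](x)        # encoder block
    bn = t.bottle(enc1)               # bottleneck
    dec1 = t.upconv_block[1](bn)      # upsampling block
    dec1 = cat(dims=3, dec1, enc1)    # skip connection: concatenate along dim 3 (channels in WHCN)
    dec1 = t.conv_block[2](dec1)      # decoder block
    dec1 = t.conv(dec1)               # final convolution; this value is returned
end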
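The failing path in the stack trace goes through gpu and fmap, and no frame in it comes from cat. One way to check cat on its own is to differentiate a concatenation along dims=3 with small CPU arrays. The shapes below are made up for illustration, using Flux's WHCN layout (width, height, channels, batch):

```julia
using Flux

# Made-up shapes mimicking the skip connection cat(dims=3, dec1, enc1).
dec1 = rand(Float32, 8, 8, 2, 1)
enc1 = rand(Float32, 8, 8, 3, 1)

# Gradient of sum(cat(...)) w.r.t. each input; every entry should be 1,
# and each gradient should have the same size as its input.
g_dec, g_enc = Flux.gradient((a, b) -> sum(cat(a, b; dims = 3)), dec1, enc1)
println(size(g_dec), " ", size(g_enc))
```

If this works, repeating it with the arrays moved to the GPU would show whether cat itself is a problem there.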

Full stack trace:

ERROR: LoadError: scalar getindex is disallowed
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] assertscalar(::String) at C:\Users\CCL\.julia\packages\GPUArrays\1wgPO\src\indexing.jl:14
 [3] getindex at C:\Users\CCL\.julia\packages\GPUArrays\1wgPO\src\indexing.jl:54 [inlined]
 [4] _getindex at .\abstractarray.jl:1004 [inlined]
 [5] getindex at .\abstractarray.jl:981 [inlined]
 [6] hash(::CuArray{Float32,4,Nothing}, ::UInt64) at .\abstractarray.jl:2203
 [7] hash at .\hashing.jl:18 [inlined]
 [8] hashindex at .\dict.jl:168 [inlined]
 [9] ht_keyindex(::Dict{Any,Any}, ::CuArray{Float32,4,Nothing}) at .\dict.jl:282
 [10] get(::Dict{Any,Any}, ::CuArray{Float32,4,Nothing}, ::Nothing) at .\dict.jl:500
 [11] (::Zygote.var"#876#877"{Zygote.Context,IdDict{Any,Any},CuArray{Float32,4,Nothing}})(::Nothing) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\lib\base.jl:44
 [12] (::Zygote.var"#2375#back#878"{Zygote.var"#876#877"{Zygote.Context,IdDict{Any,Any},CuArray{Float32,4,Nothing}}})(::Nothing) at C:\Users\CCL\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [13] #fmap#53 at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\functor.jl:37 [inlined]
 [14] (::typeof(∂(#fmap#53)))(::CuArray{Float32,4,Nothing}) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [15] fmap at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\functor.jl:36 [inlined]
 [16] (::typeof(∂(fmap)))(::CuArray{Float32,4,Nothing}) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [17] gpu at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\functor.jl:108 [inlined]
 [18] (::typeof(∂(gpu)))(::CuArray{Float32,4,Nothing}) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [19] |> at .\operators.jl:854 [inlined]
 [20] (::typeof(∂(|>)))(::CuArray{Float32,4,Nothing}) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [21] (::typeof(∂(loss)))(::Float32) at C:\Users\CCL\fcn\main_flux2.jl:47
 [22] #157 at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\lib\lib.jl:156 [inlined]
 [23] #297#back at C:\Users\CCL\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49 [inlined]
 [24] #17 at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\optimise\train.jl:88 [inlined]
 [25] (::typeof(∂(λ)))(::Float32) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface2.jl:0
 [26] (::Zygote.var"#38#39"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface.jl:101
 [27] gradient(::Function, ::Zygote.Params) at C:\Users\CCL\.julia\packages\Zygote\oMScO\src\compiler\interface.jl:47
 [28] macro expansion at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\optimise\train.jl:87 [inlined]
 [29] macro expansion at C:\Users\CCL\.julia\packages\Juno\f8hj2\src\progress.jl:134 [inlined]
 [30] #train!#12(::Flux.Optimise.var"#18#26", ::typeof(Flux.Optimise.train!), ::typeof(loss), ::Zygote.Params, ::Array{Tuple{CuArray{Float32,4,Nothing},CuArray{Float32,4,Nothing}},1}, ::ADAM) at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\optimise\train.jl:80
 [31] train!(::Function, ::Zygote.Params, ::Array{Tuple{CuArray{Float32,4,Nothing},CuArray{Float32,4,Nothing}},1}, ::ADAM) at C:\Users\CCL\.julia\packages\Flux\NpkMm\src\optimise\train.jl:78
 [32] top-level scope at C:\Users\CCL\fcn\main_flux2.jl:83
 [33] include at .\boot.jl:328 [inlined]
 [34] include_relative(::Module, ::String) at .\loading.jl:1105
 [35] include(::Module, ::String) at .\Base.jl:31
 [36] exec_options(::Base.JLOptions) at .\client.jl:287
 [37] _start() at .\client.jl:460
in expression starting at C:\Users\CCL\fcn\main_flux2.jl:74
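Frames [6] to [10] show where the error is actually raised: a Dict{Any,Any} lookup (get, ht_keyindex) hashes a CuArray key, and Julia's generic hash for an AbstractArray reads elements with getindex, which GPUArrays disallows for scalar access (frames [2] and [3]). That lookup is reached from a Zygote adjoint (frame [11]) during the backward pass through fmap and gpu (frames [13] to [18]). The plain-Julia sketch below uses neither Flux nor a GPU; a hypothetical CountingArray type stands in for the CuArray and counts element reads, to contrast a hash-based Dict lookup with an identity-based IdDict lookup:

```julia
# Hypothetical stand-in for a GPU array: a mutable AbstractArray that counts
# how many times an element is read through getindex.
mutable struct CountingArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
    reads::Int
end
CountingArray(data::Array) = CountingArray(data, 0)

Base.size(a::CountingArray) = size(a.data)
Base.IndexStyle(::Type{<:CountingArray}) = IndexLinear()
function Base.getindex(a::CountingArray, i::Int)
    a.reads += 1              # count every scalar element read
    return a.data[i]
end

# Number of element reads triggered by looking `key` up in `dict`.
function reads_during_lookup(dict, key::CountingArray)
    key.reads = 0
    get(dict, key, nothing)
    return key.reads
end

a = CountingArray(rand(Float32, 4, 4))

value_dict = Dict{Any,Any}()        # keys located via hash + isequal
value_dict["other key"] = 1
identity_dict = IdDict{Any,Any}()   # keys located via object identity
identity_dict["other key"] = 1

println(reads_during_lookup(value_dict, a))     # nonzero: hash(a) reads elements
println(reads_during_lookup(identity_dict, a))  # zero: no element is read
```

On a CuArray, each of those element reads in the Dict case is a scalar getindex, which is what assertscalar rejects.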

Thanks!
