1

我使用 Flux 库创建了一个神经网络。我的代码基于论文“神经重要性采样”。我定义了自定义层和自定义损失函数。每层返回两个输出。当我运行我的代码时,我收到此错误:

不支持变异数组…[6] LinearLayer at .\In[6]:13

这是我的代码:

using Zygote: Buffer
using Flux;

function calculate_forward(xa, xb, q, A::Int64, B::Int64, K::Int64, batch_size::Int64)
    bins = Int64.(ceil.(K*xb))
    ctransforms = zeros(B, batch_size)
    output = zeros(A+B, batch_size)
    dets = Buffer(zeros(batch_size))
    for s=1:batch_size
        det = 1.
        for i=1:B
            b = bins[i,s]
            for j=1:b-1
                ctransforms[i,s] += q[i,j,s]
            end
            ctransforms[i,s] += (K*xb[i,s]-(b-1))*q[i,b,s]
            det *= K*q[i,b,s]
        end
        dets[s] = det
    end
    output[1:A,:] = xa
    output[A+1:end,:] = ctransforms
    return output, copy(dets)
end

function calculate_inverse(za, zb, q, A::Int64, B::Int64, K::Int64, batch_size::Int64)
    xb = zeros(B, batch_size)
    x = zeros(A+B, batch_size)
    inv_dets = Buffer(zeros(batch_size))
    for s=1:batch_size
        inv_det = 1.
        for i=1:B
            q_sum = 0.
            for j=1:K
                if q_sum <= zb[i,s] < q_sum + q[i,j,s]
                    xb[i,s] += (zb[i,s]-q_sum)/(K*q[i,j,s])
                    xb[i,s] += (j-1)/K
                    inv_det *= 1/(q[i,j,s]*K)
                    break
                else
                    q_sum += q[i,j,s]
                end
            end
        end
        inv_dets[s] = inv_det
    end
    x[1:A,:] = za
    x[A+1:end,:] = xb
    return x, copy(inv_dets)
end

struct LinearLayer
    model
    A::Int64
    B::Int64
    K::Int64
    batch_size::Int64
end

struct InverseLinearLayer
    model
    A::Int64
    B::Int64
    K::Int64
    batch_size::Int64
end

function (l::LinearLayer)(input)
    A, B, K = l.A, l.B, l.K
    batch_size = l.batch_size
    if length(input) == 1
        x = input[1]
        xa, xb = x[1:A,:], x[A+1:end,:]
        q = softmax(reshape(l.model(xa), B, K, batch_size), dims=2)
        output, det = calculate_forward(xa, xb, q, A, B, K, batch_size)
    elseif length(input) == 2
        x, prev_det = input
        xa, xb = x[1:A,:], x[A+1:end,:]
        q = softmax(reshape(l.model(xa), B, K, batch_size), dims=2)
        output, det = calculate_forward(xa, xb, q, A, B, K, batch_size)
        det = det .* prev_det
    end
    return output, det
end

function (l::InverseLinearLayer)(input)
    A, B, K = l.A, l.B, l.K
    batch_size = l.batch_size
    if length(input) == 1
        z = input[1]
        za, zb = z[1:A,:], z[A+1:end,:]
        q = softmax(reshape(l.model(za), B, K, batch_size), dims=2)
        output, det = calculate_inverse(za, zb, q, A, B, K, batch_size)
    elseif length(input) == 2
        z, prev_det = input
        za, zb = z[1:A,:], z[A+1:end,:]
        q = softmax(reshape(l.model(za), B, K, batch_size), dims=2)
        output, det = calculate_inverse(za, zb, q, A, B, K, batch_size)
        det = det .* prev_det
    end
    return output, det
end

Flux.@functor LinearLayer
Flux.@functor InverseLinearLayer

function loss(x,y)
    return sum((y./forward([x])[2]).^2)
end

m1 = Chain(Dense(3,7,relu), Dense(7,16))
m2 = Chain(Dense(2,5,relu), Dense(5,24))

forward = Chain(LinearLayer(m1, 3, 2, 8, 32), LinearLayer(m2, 2, 3, 8, 32))
inverse = Chain(InverseLinearLayer(m2, 2, 3, 8, 32), InverseLinearLayer(m1, 3, 2, 8, 32))

f(x::Array{Float64, 2}) = x[1,:].*x[2,:].*exp.(x[3,:]).*sin.(x[4,:]).*sqrt.(x[5,:])

x = rand(5,32)
opt = ADAM()
for i=1:100
    Flux.train!(loss, Flux.params(forward), [(x,f(x))], opt)
    println(loss(x, f(x)))
end

上面的错误说明这部分代码有问题:

output, det = calculate_forward(xa, xb, q, A, B, K, batch_size)

calculate_forward 函数是:

function calculate_forward(xa, xb, q, A::Int64, B::Int64, K::Int64, batch_size::Int64)
    bins = Int64.(ceil.(K*xb))
    ctransforms = zeros(B, batch_size)
    output = zeros(A+B, batch_size)
    dets = Buffer(zeros(batch_size))
    for s=1:batch_size
        det = 1.
        for i=1:B
            b = bins[i,s]
            for j=1:b-1
                ctransforms[i,s] += q[i,j,s]
            end
            ctransforms[i,s] += (K*xb[i,s]-(b-1))*q[i,b,s]
            det *= K*q[i,b,s]
        end
        dets[s] = det
    end
    output[1:A,:] = xa
    output[A+1:end,:] = ctransforms
    return output, copy(dets)
end

我能做些什么来解决它?

4

0 回答 0