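I'm struggling to get my neural network to work. I have a dataset of cell images labelled by whether or not the cell is infected with malaria (https://www.kaggle.com/iarunava/cell-images-for-detecting-malaria). I arranged my data like this: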

X_training: a 30000×2668 matrix of type Array{Float64,2}
Y_training: a 1×2668 matrix of type Array{Float64,2}

X_tests and Y_tests are laid out the same way.

My simple neural network:

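using Flux
using Flux: onecold, crossentropy, throttle
using Base.Iterators: repeated
using Statistics: mean

function simple_nn(X_tests, Y_tests, X_training, Y_training)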
    input = 100*100*3   # 100×100 RGB image flattened to 30000 features
    hl1 = 32            # width of the hidden layer
    m = Chain(
      Dense(input, hl1, relu),
      Dense(hl1, 2),    # two output classes
      softmax) |> gpu

    loss(x, y) = crossentropy(m(x), y)

    accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))

    dataset = repeated((X_training, Y_training), 2)
    evalcb = () -> @show(loss(X_training, Y_training))
    opt = ADAM(params(m))

    Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 10))

    println("acc X,Y ", accuracy(X_training, Y_training))

    # The OutOfMemoryError is raised by the next line
    println("acc tX, tY ", accuracy(X_tests, Y_tests))

end

The error I get is:

ERROR: OutOfMemoryError()
...
Stacktrace:
 [1] * at ./boot.jl:396 [inlined]
 [2] _forward at /home/.../.julia/packages/Flux/8XpDt/src/tracker/lib/array.jl:361 [inlined]
 [3] #track#1 at /home/.../.julia/packages/Flux/8XpDt/src/tracker/Tracker.jl:51 [inlined]
 [4] track at /home/.../.julia/packages/Flux/8XpDt/src/tracker/Tracker.jl:51 [inlined]
 [5] * at /home/.../.julia/packages/Flux/8XpDt/src/tracker/lib/array.jl:349 [inlined]
 [6] Dense at /home/.../.julia/packages/Flux/8XpDt/src/layers/basic.jl:82 [inlined]
 [7] Dense at /home/.../.julia/packages/Flux/8XpDt/src/layers/basic.jl:122 [inlined]
 [8] (::Dense{typeof(relu),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}})(::Array{Float64,2}) at /home/.../.julia/packages/Flux/8XpDt/src/layers/basic.jl:125
 [9] applychain(::Tuple{Dense{typeof(relu),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},typeof(softmax)}, ::Array{Float64,2}) at /home/.../.julia/packages/Flux/8XpDt/src/layers/basic.jl:31
 [10] Chain at /home/.../.julia/packages/Flux/8XpDt/src/layers/basic.jl:33 [inlined]
 [11] (::getfield(Main, Symbol("#accuracy#49")){Chain{Tuple{Dense{typeof(relu),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},typeof(softmax)}}})(::Array{Float64,2}, ::Array{Float64,2}) at /home/.../neural_net.jl:108
 [12] simple_nn(::Array{Float64,2}, ::Array{Float64,2}, ::Array{Float64,2}, ::Array{Float64,2}) at /home/.../neural_net.jl:118
 [13] top-level scope at util.jl:156

Are my matrices too big?
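
For reference, a rough estimate of how much memory the input matrix alone occupies can be sketched as follows. The helper name `dense_bytes` is hypothetical (it is not part of Flux or Base), and the estimate covers only the raw array storage, not any intermediate arrays allocated during the forward pass.

```julia
# Hypothetical helper: bytes needed to store a dense rows × cols array
# whose elements have type T (raw storage only, no overhead).
dense_bytes(T::Type, rows::Integer, cols::Integer) = sizeof(T) * rows * cols

rows, cols = 30_000, 2_668   # shape of X_training as stated above

# Compare the element type of the data (Float64) with the element type
# of the layer weights shown in the stack trace (Float32).
for T in (Float64, Float32)
    mb = dense_bytes(T, rows, cols) / 10^6
    println(T, ": ", mb, " MB")
end
```

By this estimate X_training takes about 640 MB as Float64 and half that as Float32, so the element type of the data makes a real difference at this size.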
