1

下面提到的代码取自model-zoo。我正在尝试使用通量库在 julia中运行vgg19 教程。

代码:

#model
using Flux
vgg19() = Chain(            
    Conv((3, 3), 3 => 64, relu, pad=(1, 1), stride=(1, 1)),
    Conv((3, 3), 64 => 64, relu, pad=(1, 1), stride=(1, 1)),
    MaxPool((2,2)),
    Conv((3, 3), 64 => 128, relu, pad=(1, 1), stride=(1, 1)),
    Conv((3, 3), 128 => 128, relu, pad=(1, 1), stride=(1, 1)),
    MaxPool((2,2)),
    Conv((3, 3), 128 => 256, relu, pad=(1, 1), stride=(1, 1)),
    Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),
    Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),
    MaxPool((2,2)),
    Conv((3, 3), 256 => 512, relu, pad=(1, 1), stride=(1, 1)),
    Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
    Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
    MaxPool((2,2)),
    Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
    Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
    Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
    BatchNorm(512),
    MaxPool((2,2)),
    flatten,
    Dense(512, 4096, relu),
    Dropout(0.5),
    Dense(4096, 4096, relu),
    Dropout(0.5),
    Dense(4096, 10),
    softmax
)

#data

using MLDatasets: CIFAR10
using Flux: onehotbatch
# Data comes pre-normalized in Julia
trainX, trainY = CIFAR10.traindata(Float64)
testX, testY = CIFAR10.testdata(Float64)
# One hot encode labels
trainY = onehotbatch(trainY, 0:9)
testY = onehotbatch(testY, 0:9)

#training

using Flux: crossentropy, @epochs
using Flux.Data: DataLoader
model = vgg19()
opt = Momentum(.001, .9)
loss(x, y) = crossentropy(model(x), y)
data = DataLoader(trainX, trainY, batchsize=64)
@epochs 100 Flux.train!(loss, params(model), data, opt)

当我在 IJulia 上执行此文件时,会引发以下错误:

MethodError: no method matching ∇maxpool(::Array{Float32,4}, ::Array{Float64,4}, ::Array{Float64,4}, ::PoolDims{2,(2, 2),(2, 2),(0, 0, 0, 0),(1, 1)})
Closest candidates are:
  ∇maxpool(::AbstractArray{T,N}, !Matched::AbstractArray{T,N}, !Matched::AbstractArray{T,N}, ::PoolDims; kwargs...) where {T, N}

请为此错误提出一些解决方案,如果可能,请提供简要说明或参考。提前致谢!

4

1 回答 1

1

正如@mcabbott 所提到的,这个问题与数据的输入类型有关。这可以通过将typefromFloat64更改Float32为以下#data部分中提到的参数来解决。

trainX, trainY = CIFAR10.traindata(Float32)
testX, testY = CIFAR10.testdata(Float32)
于 2021-01-26T17:46:11.817 回答