我正在尝试对Flux
下面提到的代码的性能进行基准测试:
#model
using Flux
vgg19() = Chain(
Conv((3, 3), 3 => 64, relu, pad=(1, 1), stride=(1, 1)),
Conv((3, 3), 64 => 64, relu, pad=(1, 1), stride=(1, 1)),
MaxPool((2,2)),
Conv((3, 3), 64 => 128, relu, pad=(1, 1), stride=(1, 1)),
Conv((3, 3), 128 => 128, relu, pad=(1, 1), stride=(1, 1)),
MaxPool((2,2)),
Conv((3, 3), 128 => 256, relu, pad=(1, 1), stride=(1, 1)),
Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),
Conv((3, 3), 256 => 256, relu, pad=(1, 1), stride=(1, 1)),
MaxPool((2,2)),
Conv((3, 3), 256 => 512, relu, pad=(1, 1), stride=(1, 1)),
Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
MaxPool((2,2)),
Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
BatchNorm(512),
MaxPool((2,2)),
flatten,
Dense(512, 4096, relu),
Dropout(0.5),
Dense(4096, 4096, relu),
Dropout(0.5),
Dense(4096, 10),
softmax
)
#data
using MLDatasets: CIFAR10
using Flux: onehotbatch
# Data comes pre-normalized in Julia
trainX, trainY = CIFAR10.traindata(Float32)
testX, testY = CIFAR10.testdata(Float32)
# One hot encode labels
trainY = onehotbatch(trainY, 0:9)
testY = onehotbatch(testY, 0:9)
#training
using Flux: crossentropy, @epochs
using Flux.Data: DataLoader
model = vgg19()
opt = Momentum(.001, .9)
loss(x, y) = crossentropy(model(x), y)
data = DataLoader(trainX, trainY, batchsize=64)
@epochs 100 Flux.train!(loss, params(model), data, opt)
我尝试使用内置tick()
和tock()
功能来测量时间。但是,这给了执行密集比较的基本时间并且效率不高。社区中的许多开发人员都建议使用BenchmarkTools.jl
package 来对代码进行基准测试。但是当我尝试ScikitLearn Model
在 REPL 中进行基准测试时,它会产生警告;
WARNING: redefinition of constant LogisticRegression. This may fail, cause incorrect answers, or produce other errors.
REPL
同样,我尝试在using中对上述代码进行基准测试,@btime
但它会引发此错误:
julia> using BenchmarkTools
julia> @btime include("C:/Users/user/code.jl")
[ Info: Epoch 1
WARNING: both Flux and BenchmarkTools export "params"; uses of it in module Main must be qualified
ERROR: LoadError: UndefVarError: params not defined
我可以知道执行代码详细基准测试的最佳方法是什么?
提前致谢。