I am using the example from the deeplearning4j platform to run anomaly detection on my own images. I changed the code like this:
int rngSeed = 123;
Random rnd = new Random(rngSeed);
int width = 28;
int height = 28;
int batchSize = 128;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(12345)
        .iterations(1)
        .weightInit(WeightInit.XAVIER)
        .updater(Updater.ADAGRAD)
        .activation(Activation.RELU)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(0.05)
        .regularization(true).l2(0.0001)
        .list()
        .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
                .build())
        .layer(1, new DenseLayer.Builder().nIn(250).nOut(10)
                .build())
        .layer(2, new DenseLayer.Builder().nIn(10).nOut(250)
                .build())
        .layer(3, new OutputLayer.Builder().nIn(250).nOut(784)
                .lossFunction(LossFunctions.LossFunction.MSE)
                .build())
        .pretrain(false).backprop(true)
        .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(1)));
File trainData = new File("mnist_png/training");
FileSplit fsTrain = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, rnd);
ImageRecordReader recorderReader = new ImageRecordReader(height, width);
recorderReader.initialize(fsTrain);
DataSetIterator dataIt = new RecordReaderDataSetIterator(recorderReader, batchSize);
List<INDArray> featuresTrain = new ArrayList<>();
while (dataIt.hasNext()) {
    DataSet ds = dataIt.next();
    featuresTrain.add(ds.getFeatureMatrix());
}
System.out.println("************ training **************");
int nEpochs = 30;
for (int epoch = 0; epoch < nEpochs; epoch++) {
    for (INDArray data : featuresTrain) {
        net.fit(data, data);
    }
    System.out.println("Epoch " + epoch + " complete");
}
It throws an exception during training:
Exception in thread "main" org.deeplearning4j.exception.DL4JInvalidInputException: Input that is not a matrix; expected matrix (rank 2), got rank 4 array with shape [128, 1, 28, 28]
at org.deeplearning4j.nn.layers.BaseLayer.preOutput(BaseLayer.java:363)
at org.deeplearning4j.nn.layers.BaseLayer.activate(BaseLayer.java:384)
at org.deeplearning4j.nn.layers.BaseLayer.activate(BaseLayer.java:405)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.activationFromPrevLayer(MultiLayerNetwork.java:590)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForwardToLayer(MultiLayerNetwork.java:713)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:1821)
at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:151)
at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:54)
at org.deeplearning4j.optimize.Solver.optimize(Solver.java:51)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit(MultiLayerNetwork.java:1443)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit(MultiLayerNetwork.java:1408)
at org.deeplearning4j.examples.dataExamples.AnomalyTest.main(AnomalyTest.java:86)
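It seems my input has 4 dimensions (rank 4, shape [128, 1, 28, 28]) while the network expects only 2 (a rank-2 matrix), so the question is: how do I convert the ImageRecordReader output, or change something else, so that it runs properly?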
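To make the shape mismatch concrete, here is a minimal plain-Java sketch of the flattening I think is needed: each [1, 28, 28] image becomes one row of 1 * 28 * 28 = 784 values, which matches the nIn(784) of the first DenseLayer. The class FlattenImages and its flatten method are hypothetical names for illustration only and are not part of deeplearning4j. With ND4J I would expect the equivalent to be a single call along the lines of ds.getFeatureMatrix().reshape(ds.numExamples(), 784), but that is an assumption I have not verified against my version.

```java
public class FlattenImages {

    // Hypothetical helper (not a deeplearning4j API): turns a batch shaped
    // [n][channels][height][width] into a matrix shaped [n][channels*height*width].
    // Values are copied in row-major order: channel first, then row, then column.
    public static float[][] flatten(float[][][][] batch) {
        int n = batch.length;
        if (n == 0) {
            return new float[0][0];
        }
        int channels = batch[0].length;
        int height = batch[0][0].length;
        int width = batch[0][0][0].length;

        float[][] out = new float[n][channels * height * width];
        for (int i = 0; i < n; i++) {
            int k = 0; // position inside the flattened row
            for (int c = 0; c < channels; c++) {
                for (int y = 0; y < height; y++) {
                    for (int x = 0; x < width; x++) {
                        out[i][k++] = batch[i][c][y][x];
                    }
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Same shape as in the exception message: [128, 1, 28, 28]
        float[][][][] batch = new float[128][1][28][28];
        float[][] flat = flatten(batch);
        System.out.println(flat.length + " x " + flat[0].length); // prints "128 x 784"
    }
}
```

If this is the right idea, the reshaped rank-2 array would be what gets added to featuresTrain and passed to net.fit(data, data).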