I'm trying to train a neural network with deeplearning4j, but I get this error message, which I can't make sense of:
java.lang.reflect.InvocationTargetException
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at org.codehaus.mojo.exec.ExecJavaMojo$1.run(ExecJavaMojo.java:294)
at java.lang.Thread.run(Thread.java:745)
Caused by: java.lang.IllegalArgumentException: Unable to get linear index >= 1
at org.nd4j.linalg.api.ndarray.BaseNDArray.getDouble(BaseNDArray.java:3275)
at org.deeplearning4j.eval.Evaluation.eval(Evaluation.java:197)
at mypackage.myclass.main(Learn.java:77)
My data is in a CSV file. Each row has 64 numbers (values 0, 1, 2, 3) followed by a label, which is a floating-point value between -1000 and 1000.
For example:
2,3,2,2,1,1,2,3,0,1,1,2,3,1,1,0,0,0,2,2,0,0,3,1,0,1,3,1,1,1,2,2,2,2,2,2,3,2,2,2,2,3,3,1,2,2,1,3,0,0,2,3,2,3,2,0,0,3,0,1,1,3,3,2,-228.0
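To make the column layout concrete, here is a small plain-Java sketch (no DL4J dependency) that splits the sample row into its 64 feature columns (indices 0 to 63) and the label at index 64, which is the column index passed twice (64, 64) to RecordReaderDataSetIterator in the code that follows. The class, constants, and methods are hypothetical and only meant as a sanity check of the row format.

```java
import java.util.Arrays;

/**
 * Hypothetical helper: checks that a CSV row has 64 features plus one label
 * and splits it accordingly. Not part of DL4J.
 */
class CsvRowLayout {
    static final int NUM_FEATURES = 64;
    // Label column; matches the "64, 64" label-index arguments in the question's iterator.
    static final int LABEL_INDEX = 64;

    // The sample row from the question.
    static final String SAMPLE =
        "2,3,2,2,1,1,2,3,0,1,1,2,3,1,1,0,0,0,2,2,0,0,3,1,0,1,3,1,1,1,2,2,2,2,2,2,3,"
        + "2,2,2,2,3,3,1,2,2,1,3,0,0,2,3,2,3,2,0,0,3,0,1,1,3,3,2,-228.0";

    /** Parses all 65 columns; throws if the column count is wrong. */
    static double[] parse(String line) {
        String[] parts = line.split(",");
        if (parts.length != NUM_FEATURES + 1) {
            throw new IllegalArgumentException("expected 65 columns, got " + parts.length);
        }
        double[] values = new double[parts.length];
        for (int i = 0; i < parts.length; i++) {
            values[i] = Double.parseDouble(parts[i].trim());
        }
        return values;
    }

    /** Returns columns 0..63 (the network inputs). */
    static double[] features(String line) {
        return Arrays.copyOfRange(parse(line), 0, NUM_FEATURES);
    }

    /** Returns column 64 (the regression target). */
    static double label(String line) {
        return parse(line)[LABEL_INDEX];
    }

    public static void main(String[] args) {
        double[] x = features(SAMPLE);
        double y = label(SAMPLE);
        System.out.println(x.length + " features, label = " + y);
    }
}
```

Because the label is a single real-valued column, each labels row in the resulting DataSet has width 1, which matches outputNum = 1 in the network configuration.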
I load the CSV file and train the network with this code:
RecordReader recordReader = new CSVRecordReader(0, ",");
recordReader.initialize(new FileSplit(new File("data.csv")));
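// Batch size 600000, so iterator.next() returns up to 600000 rows (the whole file if it is smaller); label = column 64 (64, 64); regression = true
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, new DoubleWritableConverter(), 600000, 64, 64, true);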
DataSet allData = iterator.next();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.9);
DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();
//We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.transform(trainingData); //Apply normalization to the training data
normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set
long seed = 123;
int inputNum = 64;
int hiddenNum = 64;
int outputNum = 1;
int iterations = 1;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .activation("tanh")
        .iterations(iterations)
        .weightInit(WeightInit.XAVIER)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(0.1)
        .regularization(true).l2(1e-4)
        .list()
        .layer(0, new DenseLayer.Builder().nIn(inputNum).nOut(hiddenNum).build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation("identity")
                .nIn(hiddenNum).nOut(outputNum).build())
        .backprop(true).pretrain(false)
        .build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100));
model.fit(trainingData);
//evaluate the model on the test set
Evaluation eval = new Evaluation(2);
INDArray output = model.output(testData.getFeatureMatrix());
eval.eval(testData.getLabels(), output); // <-- line 77 of Learn.java, where the exception is thrown
System.out.println(eval.stats());
recordReader.close();
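As the comments in the code say, NormalizerStandardize collects per-column statistics from the training set only and then applies them to both the training and the test data. The plain-Java sketch below illustrates that idea, z = (x - mean) / std, for a single column. It assumes the population standard deviation and leaves constant columns unscaled; both are assumptions, and the class name is hypothetical. It is not DL4J's actual implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of z-score standardization for one column:
 * statistics are fitted on training values only, then applied to any data.
 */
class ColumnStandardizer {
    private double mean;
    private double std;

    /** Collects mean and (population) standard deviation; does not modify the input. */
    void fit(double[] train) {
        double sum = 0;
        for (double v : train) sum += v;
        mean = sum / train.length;
        double sq = 0;
        for (double v : train) sq += (v - mean) * (v - mean);
        std = Math.sqrt(sq / train.length);
        if (std == 0) std = 1; // assumption: constant columns are only centered, not scaled
    }

    /** Returns a standardized copy using the statistics from fit(). */
    double[] transform(double[] values) {
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (values[i] - mean) / std;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] train = {0, 1, 2, 3};
        double[] test = {1, 3};
        ColumnStandardizer s = new ColumnStandardizer();
        s.fit(train);                                          // statistics from training data only
        System.out.println(Arrays.toString(s.transform(train)));
        System.out.println(Arrays.toString(s.transform(test))); // test reuses the training mean/std
    }
}
```

Fitting on the training split alone keeps information from the test rows out of the statistics, which is the reason the code calls fit(trainingData) but transform on both sets.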
What does this error mean, and how can I fix it?