package com.example.minwoo_k.neural_network;
import android.os.AsyncTask;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.reflections.vfs.CommonsVfs2UrlType;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import static android.R.id.input;
import static org.reflections.Reflections.log;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
AsyncTask.execute(new Runnable() {
@Override
public void run() {
try {
createAndUseNetwork();
} catch (IOException e) {
e.printStackTrace();
}
}
});
}
private void createAndUseNetwork() throws IOException {
DenseLayer inputLayer = new DenseLayer.Builder() // Input Layer
.nIn(784)
.nOut(200)
.name("Input")
.activation(Activation.SIGMOID) // Sigmoid Activation function
.build();
DenseLayer hiddenLayer = new DenseLayer.Builder() // Hidden Layer
.nIn(200)
.nOut(10)
.name("Hidden")
.activation(Activation.SIGMOID) // Sigmoid Activation function
.build();
OutputLayer outputLayer = new OutputLayer.Builder() // Output Layer
.nIn(10)
.nOut(10)
.name("Output")
.activation(Activation.SOFTMAX) // Softmax Activation function
.build();
NeuralNetConfiguration.Builder nncBuilder = new NeuralNetConfiguration.Builder();
nncBuilder.iterations(5);
nncBuilder.learningRate(0.05); // Learning Rate
nncBuilder.weightInit(WeightInit.XAVIER);
nncBuilder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); // use SGD
NeuralNetConfiguration.ListBuilder listBuilder = nncBuilder.list();
listBuilder.layer(0, inputLayer);
listBuilder.layer(1, hiddenLayer);
listBuilder.layer(2, outputLayer);
listBuilder.backprop(true); // backpropagation
Log.d("ANN","****************Create ANN********************");
MultiLayerNetwork myNetwork = new MultiLayerNetwork(listBuilder.build());
myNetwork.init();
myNetwork.setListeners(new ScoreIterationListener(1));
Log.d("ANN","****************Get Data********************");
DataSetIterator mnistTrain = new MnistDataSetIterator(500, 10000, true);
DataSetIterator mnistTest = new MnistDataSetIterator(500, 100, true);
Log.d("ANN","****************Train ANN********************");
myNetwork.fit(mnistTrain);
Log.d("ANN","****************Evaluate ANN********************");
Evaluation eval = new Evaluation(10); //create an evaluation object with 10 possible classes
while(mnistTest.hasNext()){
DataSet next = mnistTest.next();
INDArray output = myNetwork.output(next.getFeatureMatrix()); //get the networks prediction
eval.eval(next.getLabels(), output); //check the prediction against the true class
}
log.info(eval.stats());
log.info("****************Example finished********************");
}
}
这是我的程序的完整源代码,我无法读取 mnist 数据。如何获取 mnist 数据集?
12-15 12:26:06.526 3910-3930/com.example.minwoo_k.neural_network W/System.err: java.io.IOException: 无法 mkdir /MNIST 12-15 12:26:06.526 3910-3930/com。 example.minwoo_k.neural_network W/System.err: at org.deeplearning4j.base.MnistFetcher.downloadAndUntar(MnistFetcher.java:66) 12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System .err:在 org.deeplearning4j.datasets.fetchers.MnistDataFetcher.(MnistDataFetcher.java:65) 12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:在 org.deeplearning4j .datasets.iterator.impl.MnistDataSetIterator.(MnistDataSetIterator.java:65) 12-15 12:26:06.529 3910-3930/com.example.minwoo_k.neural_network W/System.err:在 org.deeplearning4j.datasets.iterator。 impl.MnistDataSetIterator.(MnistDataSetIterator.java:43) 12-15 12:26:06。529 3910-3930/com.example.minwoo_k.neural_network W/System.err:在 com.example.minwoo_k.neural_network.MainActivity.createAndUseNetwork(MainActivity.java:93) 12-15 12:26:06.529 3910-3930/com .example.minwoo_k.neural_network W/System.err:在 com.example.minwoo_k.neural_network.MainActivity.access$000(MainActivity.java:33) 12-15 12:26:06.531 3910-3930/com.example.minwoo_k。神经网络 W/System.err:在 com.example.minwoo_k.neural_network.MainActivity$1.run(MainActivity.java:44) 12-15 12:26:06.531 3910-3930/com.example.minwoo_k.neural_network W/System。错误:在 android.os.AsyncTask$SerialExecutor$1.run(AsyncTask.java:245) 12-15 12:26:06.532 3910-3930/com.example.minwoo_k.neural_network W/System.err:在 java.util。 concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1162) 12-15 12:26:06。532 3910-3930/com.example.minwoo_k.neural_network W/System.err:在 java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:636) 12-15 12:26:06.532 3910-3930/com .example.minwoo_k.neural_network W/System.err:在 java.lang.Thread.run(Thread.java:764)
这是我的 Logcat 记录。我该如何解决这个问题?