When I run the MNIST example in my own project, I get the following error:
o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
Exception in thread "main" java.lang.NoSuchFieldError: HALF
at org.nd4j.linalg.factory.Nd4j.initWithBackend(Nd4j.java:5593)
at org.nd4j.linalg.factory.Nd4j.initContext(Nd4j.java:5554)
at org.nd4j.linalg.factory.Nd4j.<clinit>(Nd4j.java:189)
at org.deeplearning4j.nn.conf.NeuralNetConfiguration$Builder.seed(NeuralNetConfiguration.java:624)
at com.baus.visualagent.DigitTrainer.startTraining(DigitTrainer.java:47)
at com.baus.visualagent.App.main(App.java:17)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:497)
at com.intellij.rt.execution.application.AppMain.main(AppMain.java:147)
It would be great to know what is causing this. Is my POM file misconfigured, or is it something else?
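A NoSuchFieldError at runtime usually means a class was compiled against one version of a library but a different version was loaded. One way to narrow that down is to ask the JVM which JAR each relevant class came from. The sketch below is a hypothetical JDK-only helper (WhichJar is not part of ND4J or DL4J). It assumes the missing HALF is a constant of ND4J's DataBuffer.Type enum. That class name is inferred from the error message and is not shown in the stack trace.

```java
import java.security.CodeSource;

/** Hypothetical diagnostic, not part of ND4J/DL4J: reports which JAR or directory a class was loaded from. */
public class WhichJar {

    static String locate(String className) {
        try {
            // initialize = false: load the class without running its static initializer,
            // so probing Nd4j does not re-trigger the failing Nd4j.<clinit>.
            Class<?> cls = Class.forName(className, false, WhichJar.class.getClassLoader());
            CodeSource source = cls.getProtectionDomain().getCodeSource();
            // JDK classes loaded by the bootstrap loader report no code source.
            if (source == null || source.getLocation() == null) {
                return "bootstrap/unknown";
            }
            return source.getLocation().toString();
        } catch (ClassNotFoundException e) {
            return "not found";
        }
    }

    public static void main(String[] args) {
        // Nd4j appears in the stack trace; DataBuffer$Type is an assumed location of HALF.
        String[] names = {
            "org.nd4j.linalg.factory.Nd4j",
            "org.nd4j.linalg.api.buffer.DataBuffer$Type"
        };
        for (String name : names) {
            System.out.println(name + " -> " + locate(name));
        }
    }
}
```

If the two classes resolve to JARs from different ND4J releases, that version mismatch would be a likely explanation for the error.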
The POM file is:
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.baus.visualagent</groupId>
<artifactId>Visual Agent</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<name>Visual Agent</name>
<url>http://maven.apache.org</url>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<nd4j.backend>nd4j-native-platform</nd4j.backend>
<shadedClassifier>bin</shadedClassifier>
<java.version>1.8</java.version>
<nd4j.version>0.7.2</nd4j.version>
<dl4j.version>0.7.2</dl4j.version>
<datavec.version>0.7.2</datavec.version>
<arbiter.version>0.7.2</arbiter.version>
<rl4j.version>0.7.2</rl4j.version>
<guava.version>19.0</guava.version>
<logback.version>1.1.7</logback.version>
<jfreechart.version>1.0.13</jfreechart.version>
<jcommon.version>1.0.23</jcommon.version>
<maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
<exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
<maven.minimum.version>3.3.1</maven.minimum.version>
</properties>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>3.8.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>0.7.2</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>0.7.2</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>canova-nd4j-image</artifactId>
<version>0.0.0.17</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>canova-nd4j-codec</artifactId>
<version>0.0.0.17</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>0.7.2</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-blas</artifactId>
<version>unknown</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-x86</artifactId>
<version>0.4-rc3.8</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-examples</artifactId>
<version>0.0.3.5.4</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-examples</artifactId>
<version>0.7-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-modelimport</artifactId>
<version>0.7.2</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nn</artifactId>
<version>0.7.2</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui</artifactId>
<version>0.6.0</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-enforcer-plugin</artifactId>
<executions>
<execution>
<id>enforce-default</id>
<goals>
<goal>enforce</goal>
</goals>
<configuration>
<rules>
<requireMavenVersion>
<version>[${maven.minimum.version},)</version>
<message>********** Minimum Maven Version is ${maven.minimum.version}. Please upgrade Maven before continuing (run "mvn --version" to check). **********</message>
</requireMavenVersion>
</rules>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
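This POM mixes artifacts from very different release lines (0.7.2, 0.6.0, 0.4-rc3.8, 0.0.0.17, a 0.7-SNAPSHOT). Running mvn dependency:tree is the usual first check. The sketch below instead inspects the classpath that the IDE actually assembles at runtime. It is a hypothetical helper (ClasspathAudit is not a DL4J/ND4J API). It lists every copy of a class file, and it reads the version from each META-INF/maven/groupId/artifactId/pom.properties it finds. Maven-built JARs normally include that file, but not every JAR does. As above, the DataBuffer$Type path is an assumption.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Properties;

/** Hypothetical diagnostic, not part of ND4J/DL4J: finds duplicate classes and mixed artifact versions. */
public class ClasspathAudit {

    /** Every classpath location containing the resource; more than one means several JARs ship it. */
    static List<URL> copiesOf(String resourcePath) {
        try {
            ClassLoader loader = ClasspathAudit.class.getClassLoader();
            return Collections.list(loader.getResources(resourcePath));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** The "version" entry of each matching Maven pom.properties file found on the classpath. */
    static List<String> versionsOf(String groupId, String artifactId) {
        List<String> versions = new ArrayList<>();
        String path = "META-INF/maven/" + groupId + "/" + artifactId + "/pom.properties";
        for (URL url : copiesOf(path)) {
            try (InputStream in = url.openStream()) {
                Properties props = new Properties();
                props.load(in);
                versions.add(props.getProperty("version", "?"));
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
        return versions;
    }

    public static void main(String[] args) {
        // Assumed location of the enum that should contain HALF.
        System.out.println("DataBuffer$Type copies: "
                + copiesOf("org/nd4j/linalg/api/buffer/DataBuffer$Type.class"));
        // Coordinates copied from the POM; ideally each list holds exactly one version.
        String[][] artifacts = {
            {"org.nd4j", "nd4j-api"},
            {"org.nd4j", "nd4j-x86"},
            {"org.nd4j", "canova-nd4j-image"},
            {"org.deeplearning4j", "deeplearning4j-core"},
            {"org.deeplearning4j", "deeplearning4j-ui"}
        };
        for (String[] a : artifacts) {
            System.out.println(a[0] + ":" + a[1] + " -> " + versionsOf(a[0], a[1]));
        }
    }
}
```

More than one URL for the class file, or ND4J/DL4J versions other than 0.7.2, would point to the POM as the cause.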
The code consists of two classes:
- a main class named App
- a neural network class named DigitTrainer
The code for App:
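package com.baus.visualagent;

import java.io.IOException;

public class App
{
    public static void main( String[] args ) throws IOException
    {
        // Train on the MNIST training set, then evaluate on the test set.
        DigitTrainer t = new DigitTrainer();
        t.startTraining();
        t.startTesting();
    }
}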
The code for DigitTrainer:
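package com.baus.visualagent;

import java.io.IOException;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class DigitTrainer {
    private int rows;    // image height in pixels
    private int cols;    // image width in pixels
    private int out;     // number of classes (digits 0-9)
    private int batch;   // minibatch size
    private int seed;    // RNG seed for reproducibility
    private int epochs;  // passes over the training set
    private DataSetIterator test, train;
    private MultiLayerConfiguration config;
    private MultiLayerNetwork ann;

    public DigitTrainer() throws IOException {
        rows = 28;
        cols = 28;
        out = 10;
        batch = 128;
        seed = 123;
        epochs = 20;
        // true = training set, false = test set
        train = new MnistDataSetIterator(batch, true, seed);
        test = new MnistDataSetIterator(batch, false, seed);
    }

    public void startTraining() {
        config = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1)
                .learningRate(0.006)
                .updater(Updater.NESTEROVS)
                .momentum(0.9)
                .regularization(true).l2(1e-4)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(rows * cols)
                        .nOut(1000)
                        .activation(Activation.IDENTITY)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                // The last layer must be an OutputLayer with a loss function;
                // a plain DenseLayer leaves backprop nothing to minimize.
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(1000)
                        .nOut(out)
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .pretrain(false)
                .backprop(true)
                .build();
        ann = new MultiLayerNetwork(config);
        ann.init();
        ann.setListeners(new ScoreIterationListener(1)); // log the score every iteration
        System.out.println("\n******Beginning Training******\n");
        for (int i = 0; i < epochs; i++) {
            train.reset(); // rewind so every epoch sees the full training set
            ann.fit(train);
        }
        System.out.println("\n******Model Trained******\n");
    }

    public void startTesting() {
        System.out.println("\n******Starting Testing******\n");
        Evaluation e = new Evaluation(out);
        while (test.hasNext()) {
            DataSet next = test.next();
            INDArray x = ann.output(next.getFeatureMatrix());
            e.eval(next.getLabels(), x);
        }
        System.out.println(e.stats());
    }
}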
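For a sense of scale once the network runs, here is a back-of-the-envelope calculation of the model size and the number of minibatches. It uses the standard formula for a fully connected layer (nIn * nOut weights plus nOut biases) and the usual 60,000/10,000 MNIST train/test split. The split is assumed and is not stated in the post. NetworkSize is a hypothetical helper, not a DL4J class.

```java
/** Hypothetical helper: rough sizes for the network configured in startTraining(). */
public class NetworkSize {

    /** Weights (nIn * nOut) plus one bias per output unit for a fully connected layer. */
    static long denseParams(long nIn, long nOut) {
        return nIn * nOut + nOut;
    }

    /** Minibatches per pass over the data, counting a final partial batch. */
    static long batchesPerEpoch(long examples, long batchSize) {
        return (examples + batchSize - 1) / batchSize;
    }

    public static void main(String[] args) {
        long layer0 = denseParams(28 * 28, 1000); // 785000
        long layer1 = denseParams(1000, 10);      // 10010
        System.out.println("parameters: " + (layer0 + layer1)); // 795010
        // Assumed standard MNIST split: 60000 training / 10000 test images.
        System.out.println("train batches per epoch: " + batchesPerEpoch(60000, 128)); // 469
        System.out.println("test batches: " + batchesPerEpoch(10000, 128));            // 79
    }
}
```

With iterations(1), each minibatch should count as one iteration. ScoreIterationListener(1) would then print roughly 469 score lines per epoch, or about 9,380 over 20 epochs. This is a reading of the configuration and has not been verified.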