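I am new to Spark and am currently trying to build a neural network with the deeplearning4j API. Training works fine, but I run into a problem during evaluation. I get the following error message: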
18:16:16,206 ERROR ~ Exception in task 0.0 in stage 14.0 (TID 19)
java.lang.IllegalStateException: Network did not have same number of parameters as the broadcasted set parameters
at org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction.call(EvaluateFlatMapFunction.java:75)
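I cannot find the cause of this problem, and there is very little information available on using Spark together with deeplearning4j. I essentially took the structure from this example: https://github.com/deeplearning4j/dl4j-spark-cdh5-examples/blob/2de0324076fb422e2bdb926a095adb97c6d0e0ca/src/main/java/org/deeplearning4j/examples/mlp/IrisLocal.java
Here is my code: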
public class DeepBeliefNetwork {

    private JavaRDD<DataSet> trainSet;
    private JavaRDD<DataSet> testSet;
    private int inputSize;
    private int numLab;
    private int batchSize;
    private int iterations;
    private int seed;
    private int listenerFreq;
    MultiLayerConfiguration conf;
    MultiLayerNetwork model;
    SparkDl4jMultiLayer sparkmodel;
    JavaSparkContext sc;
    MLLibUtil mllibUtil = new MLLibUtil();

    public DeepBeliefNetwork(JavaSparkContext sc, JavaRDD<DataSet> trainSet, JavaRDD<DataSet> testSet, int numLab,
            int batchSize, int iterations, int seed, int listenerFreq) {
        this.trainSet = trainSet;
        this.testSet = testSet;
        this.numLab = numLab;
        this.batchSize = batchSize;
        this.iterations = iterations;
        this.seed = seed;
        this.listenerFreq = listenerFreq;
        // The input size is read from the first example of the test set
        this.inputSize = testSet.first().numInputs();
        this.sc = sc;
    }

    public void build() {
        System.out.println("input Size: " + inputSize);
        System.out.println(trainSet.first().toString());
        System.out.println(testSet.first().toString());
        // Three stacked RBM layers followed by a softmax output layer
        conf = new NeuralNetConfiguration.Builder().seed(seed)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0).iterations(iterations).momentum(0.5)
                .momentumAfter(Collections.singletonMap(3, 0.9))
                .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list(4)
                .layer(0,
                        new RBM.Builder().nIn(inputSize).nOut(500).weightInit(WeightInit.XAVIER)
                                .lossFunction(LossFunction.RMSE_XENT).visibleUnit(RBM.VisibleUnit.BINARY)
                                .hiddenUnit(RBM.HiddenUnit.BINARY).build())
                .layer(1,
                        new RBM.Builder().nIn(500).nOut(250).weightInit(WeightInit.XAVIER)
                                .lossFunction(LossFunction.RMSE_XENT).visibleUnit(RBM.VisibleUnit.BINARY)
                                .hiddenUnit(RBM.HiddenUnit.BINARY).build())
                .layer(2,
                        new RBM.Builder().nIn(250).nOut(200).weightInit(WeightInit.XAVIER)
                                .lossFunction(LossFunction.RMSE_XENT).visibleUnit(RBM.VisibleUnit.BINARY)
                                .hiddenUnit(RBM.HiddenUnit.BINARY).build())
                .layer(3, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).activation("softmax").nIn(200)
                        .nOut(numLab).build())
                .pretrain(true).backprop(false).build();
    }

    public void trainModel() {
        model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq)));

        // Create Spark multi layer network from configuration
        sparkmodel = new SparkDl4jMultiLayer(sc.sc(), model);
        sparkmodel.fitDataSet(trainSet);

        // Evaluation: this is the call that throws the exception
        Evaluation evaluation = sparkmodel.evaluate(testSet);
        System.out.println(evaluation.stats());
    }
}
Does anyone have a suggestion on how I should handle my JavaRDD? I believe the problem lies there.
Thanks a lot!
Edit 1
I am using deeplearning4j version 0.4-rc.10 and Spark 1.5.0. Here is the stack trace:
11:03:53,088 ERROR ~ Exception in task 0.0 in stage 16.0 (TID 21)
java.lang.IllegalStateException: Network did not have same number of parameters as the broadcasted set parameter
at org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction.call(EvaluateFlatMapFunction.java:75)
at org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction.call(EvaluateFlatMapFunction.java:41)
at org.apache.spark.api.java.JavaRDDLike$$anonfun$fn$4$1.apply(JavaRDDLike.scala:156)
at org.apache.spark.api.java.JavaRDDLike$$anonfun$fn$4$1.apply(JavaRDDLike.scala:156)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$17.apply(RDD.scala:706)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$17.apply(RDD.scala:706)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:297)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:264)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
at org.apache.spark.scheduler.Task.run(Task.scala:88)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1280)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1268)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1267)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1267)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:697)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:697)
at scala.Option.foreach(Option.scala:236)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:697)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1493)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1455)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1444)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:567)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1813)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1933)
at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:1003)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:147)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:108)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:306)
at org.apache.spark.rdd.RDD.reduce(RDD.scala:985)
at org.apache.spark.api.java.JavaRDDLike$class.reduce(JavaRDDLike.scala:375)
at org.apache.spark.api.java.AbstractJavaRDDLike.reduce(JavaRDDLike.scala:47)
at org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer.evaluate(SparkDl4jMultiLayer.java:629)
at org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer.evaluate(SparkDl4jMultiLayer.java:607)
at org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer.evaluate(SparkDl4jMultiLayer.java:597)
at deep.deepbeliefclassifier.DeepBeliefNetwork.trainModel(DeepBeliefNetwork.java:117)
at deep.deepbeliefclassifier.DataInput.main(DataInput.java:105)
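To make clearer what the exception appears to be comparing, here is a minimal sketch in plain Java that counts the parameters for the layer sizes in my configuration. It does not use deeplearning4j at all. It assumes that each RBM layer stores a weight matrix, a hidden bias and a visible bias, and that the output layer stores only weights and a bias. The class name ParamCount, the helper methods, and the values 784 and 10 are hypothetical placeholders, not my real data and not part of the deeplearning4j API.

```java
public class ParamCount {

    // Weights plus one bias per output unit.
    static long denseParams(int nIn, int nOut) {
        return (long) nIn * nOut + nOut;
    }

    // Assumption: an RBM additionally keeps one visible bias per input unit.
    static long rbmParams(int nIn, int nOut, boolean withVisibleBias) {
        return denseParams(nIn, nOut) + (withVisibleBias ? nIn : 0);
    }

    // Layer sizes follow the configuration in the question:
    // inputSize -> 500 -> 250 -> 200 -> numLab
    static long total(int inputSize, int numLab, boolean withVisibleBias) {
        return rbmParams(inputSize, 500, withVisibleBias)
                + rbmParams(500, 250, withVisibleBias)
                + rbmParams(250, 200, withVisibleBias)
                + denseParams(200, numLab);
    }

    public static void main(String[] args) {
        int inputSize = 784; // hypothetical input size
        int numLab = 10;     // hypothetical number of labels
        long withBias = total(inputSize, numLab, true);
        long withoutBias = total(inputSize, numLab, false);
        System.out.println("with visible bias:    " + withBias);
        System.out.println("without visible bias: " + withoutBias);
        System.out.println("difference:           " + (withBias - withoutBias));
    }
}
```

If the two totals differ for the real inputSize and numLab, a mismatch of exactly that size between the broadcast parameters and the rebuilt network might hint that one side counts the pretraining parameters and the other does not. That is only a guess on my part, not something I have confirmed in the deeplearning4j source.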