我目前正在使用 nd4j 和 dl4j 来实现一些深度学习算法。但是,我首先无法让 datavec + dl4j 工作。
这是我的图像转换器:
public class ImageConverter {
private static Logger log = LoggerFactory.getLogger(ImageConverter.class);
public DataSetIterator Convert() throws IOException, InterruptedException {
log.info("Start to convert images...");
File parentDir = new File(System.getProperty("user.dir"), "src/main/resources/images/");
ParentPathLabelGenerator parentPathLabelGenerator = new ParentPathLabelGenerator();
ImageRecordReader recordReader = new ImageRecordReader(28,28,1,parentPathLabelGenerator);
FileSplit fs = new FileSplit(parentDir);
InputSplit[] filesInDirSplit = fs.sample(null, 100);
recordReader.initialize(filesInDirSplit[0]);
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, 2, 1, 2);
log.info("Image convert finished.");
return dataIter;
}
}
这是主要课程:
ImageConverter icv = new ImageConverter();
DataSetIterator dataSetIterator = icv.Convert();
log.info("Build model....");
int numEpochs = 10;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.iterations(1)
.learningRate(0.006)
.updater(Updater.NESTEROVS).momentum(0.9)
.regularization(true).l2(1e-4)
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
.nIn(28 * 28)
.stride(1, 1)
.nOut(20)
.activation("identity")
.build())
.layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(24 * 24)
.nOut(2)
.activation("softmax")
.build())
.pretrain(false)
.backprop(true)
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(1));
log.info("Train model....");
for( int i=0; i<numEpochs; i++ ){
model.fit(dataSetIterator);
}
在图像文件夹中,我有一些灰度 28x28 图像分别位于子文件夹a
中b
。
然而,Exception in thread "main" java.lang.IllegalStateException: Unable to get number of of rows for a non 2d matrix
被抛出。
通过查看数据dataSetIterator.next().toString()
,它类似于:
[[[...],
...
]]
=================OUTPUT==================
[[1.00, 0.00],
[1.00, 0.00]]
此外,输出dataSetIterator.next().get(0).toString()
为
[[[[...],
...
]]]
=================OUTPUT==================
[1.00, 0.00]
对于示例中的 mnisterIterator,mnisterIterator.next().toString()
应该类似于:
[[...]...]
=================OUTPUT==================
[[0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
...]
我从中推断出dataSetIterator
我返回的包含格式错误的数据。
有人知道如何解决吗?