我是 deeplearning4J 的新手。我已经尝试过它的 word2vec 功能,一切都很好。但是现在我对图像分类有点困惑。我在玩这个例子:
我将“保存”标志更改为 true,并将我的模型存储到 model.bin 文件中。现在是有问题的部分(如果这听起来很愚蠢,我很抱歉,也许我在这里遗漏了一些非常明显的东西)
我创建了一个名为 AnimalClassifier 的单独类,其目的是从 model.bin 文件加载模型,从中恢复神经网络,然后使用恢复的网络对单个图像进行分类。对于这张单张图片,我创建了“temp”文件夹 -> dl4j-examples/src/main/resources/animals/temp/ 我将之前在 AnimalsClassification.java 中训练过程中使用的北极熊图片放入其中(我想确定该图像将被正确分类-因此我重用了“熊”文件夹中的图片)。
这是我试图对北极熊进行分类的代码:
protected static int height = 100;
protected static int width = 100;
protected static int channels = 3;
protected static int numExamples = 1;
protected static int numLabels = 1;
protected static int batchSize = 10;
protected static long seed = 42;
protected static Random rng = new Random(seed);
protected static int listenerFreq = 1;
protected static int iterations = 1;
protected static int epochs = 7;
protected static double splitTrainTest = 0.8;
protected static int nCores = 2;
protected static boolean save = true;
protected static String modelType = "AlexNet"; //
public static void main(String[] args) throws Exception {
String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/");
MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(basePath + "model.bin", true);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/temp/");
FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);
InputSplit[] inputSplit = fileSplit.sample(pathFilter, 1);
InputSplit analysedData = inputSplit[0];
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels);
recordReader.initialize(analysedData);
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 0, 4);
while (dataIter.hasNext()) {
DataSet testDataSet = dataIter.next();
String expectedResult = testDataSet.getLabelName(0);
List<String> predict = multiLayerNetwork.predict(testDataSet);
String modelResult = predict.get(0);
System.out.println("\nFor example that is labeled " + expectedResult + " the model predicted " + modelResult + "\n\n");
}
}
运行此程序后,我收到错误:
此数据集上未定义标签名称。添加标签名称以便使用带有 id 的 getLabelName。在 org.nd4j.linalg.dataset.DataSet.getLabelName(DataSet.java:1106) 在 org.deeplearning4j.examples.convolution.AnimalClassifier.main(AnimalClassifier.java:68)
我可以看到 MultiLayerNetwork.java 中有一个方法 public void setLabels(INDArray labels) 但我不知道如何使用(尤其是当它作为参数 INDArray 时)。
我也很困惑为什么我必须在 RecordReaderDataSetIterator 的构造函数中指定可能的标签数量。我希望该模型已经知道要使用哪些标签(它不应该使用在训练期间自动使用的标签吗?)。我想,也许我以完全错误的方式加载图片......
总而言之,我想简单地实现以下目标:
- 从模型恢复网络(这是有效的)
- 加载要分类的图像(也可以工作)
- 使用训练期间使用的相同标签(熊、鹿、鸭、乌龟)(棘手的部分)对该图像进行分类
提前感谢您的帮助或任何提示!