我想打印分类中使用的 traindata / testdata 的标签。这是两个输入的定义(使用 deep4j)。
InputSplit[] inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest);
InputSplit trainData = inputSplit[0];
InputSplit testData = inputSplit[1];
然后像这样在 DataSetIterator 中转换:
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
recordReader.initialize(trainData, null);
trainIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);
然后我想打印在这个函数的每个迭代器中找到的每个标签有多少个例子:
public void print(DataSetIterator iter){
HashMap<String, Integer> hash = new HashMap<String, Integer>();
while(iter.hasNext()){
DataSet example = iter.next();
for(int i = 0 ; i<numLabels ; i++){
if(example.getLabels().getDouble(i)==1.){
String label = example.getLabelName(i);
if(hash.containsKey(label))
hash.put(label, hash.get(label)+1);
else
hash.put(label, 1);
}
}
}
for (String label: hash.keySet()){
System.out.println(" label : " + label.toString() + ", " + hash.get(label) + " examples");
}
}
问题是它每个标签只显示一个示例,而应该有更多......当我不使用fileSplit.sample()
该函数拆分我的数据集时,会显示正确数量的示例。有什么建议吗?