0

我想打印分类中使用的 traindata / testdata 的标签。这是两个输入的定义(使用 deep4j)。

    InputSplit[] inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest);
    InputSplit trainData = inputSplit[0];
    InputSplit testData = inputSplit[1];

然后像这样在 DataSetIterator 中转换:

    ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
    recordReader.initialize(trainData, null);
    trainIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);

然后我想打印在这个函数的每个迭代器中找到的每个标签有多少个例子:

public void print(DataSetIterator iter){

    HashMap<String, Integer> hash = new HashMap<String, Integer>();

    while(iter.hasNext()){
        DataSet example = iter.next();
        for(int i = 0 ; i<numLabels ; i++){
            if(example.getLabels().getDouble(i)==1.){
                String label = example.getLabelName(i);
                if(hash.containsKey(label))
                    hash.put(label, hash.get(label)+1);
                else
                    hash.put(label, 1);
            }
        }
    }

    for (String label: hash.keySet()){
        System.out.println("   label : " + label.toString() + ", " + hash.get(label) + " examples");
    }
}

问题是它每个标签只显示一个示例,而应该有更多......当我不使用fileSplit.sample()该函数拆分我的数据集时,会显示正确数量的示例。有什么建议吗?

4

1 回答 1

0

如果您使用数据集,您可以使用 dataset.getFeatureMatrix() 和 dataset.getLabels() 的 toString()

如果您只想打印标签计数,可以使用 dataset.labelCounts() 我会更多地查看 dl4j javadoc: http ://deeplearning4j.org/doc

于 2016-11-11T14:26:54.327 回答