0

我正在构建一个决策树分类器,我发现了这种计算信息增益的方法。这可能是一个愚蠢的问题,但我想知道这种方法中的拆分是针对数字属性还是分类属性?我很困惑,因为我认为阈值(中位数)用于数字拆分,但此方法使用字符串值。

任何帮助表示赞赏。

这是代码:

    public static double getInfoGain(int f, ArrayList<String[]> dataSubset) {
            double entropyBefore = getEntropy(dataSubset); //Entropy before split
            if(entropyBefore != 0){ // Calculate information gain if entropy is not 0
                String threshold = thresholdMap.get(f); // Get threshold value of the feature
                ArrayList<String[]> leftData = new ArrayList<String[]>();
                ArrayList<String[]> rightData = new ArrayList<String[]>();
                for(String[] d : dataSubset) {
                    if(d[f].equals(threshold)) {
                        leftData.add(d); // If feature value of data == threshold, add it to leftData
                    } else {
                        rightData.add(d); // If feature value of data != threshold, add it to leftData
                    }
                }
                if(leftData.size() > 0 && rightData.size() > 0) {
                    double leftProb = (double)leftData.size()/dataSubset.size(); 
                    double rightProb = (double)rightData.size()/dataSubset.size();
                    double entropyLeft = getEntropy(leftData); //Entropy after split - left
                    double entropyRight = getEntropy(rightData); //Entropy after split - right
                    double gain = entropyBefore - (leftProb * entropyLeft) - (rightProb * entropyRight);
                    return gain;
                } else { // If entropy = 0 on either subsets of data, return 0
                    return 0;
                }
            } else { // If entropy = 0 before split, return 1
                return -1;
            }
        }
4

1 回答 1

0

尽管您指向的代码使用了阈值的术语,但如果您查看注释,它会以分类或二进制方式使用它们。

if(d[f].equals(threshold)) {
   leftData.add(d); // If feature value of data == threshold, add it to leftData
} else {
   rightData.add(d); // If feature value of data != threshold, add it to leftData
}

我强烈建议您查看教科书或维基百科中的算法作为参考,而不是直接阅读代码。或者,如果您发现自己需要代码示例,我会在 Github 上寻找更高质量(三个维度)的存储库。

  1. 您想使用明确的许可证来学习代码。在许多地方,没有许可证就等于是专有的,尽管 Github 隐含着开源性质,但这在法律上并不准确。
  2. 你想研究人们使用的代码。github 上有更多的决策树算法实现,它们的星数和问题不只为零。
  3. 如果做不到这一点,你想研究有测试的代码(一个指示和一个测试它是否真的适合你自己的机会)。

理想情况下,您需要许多信任迹象。如果我去 github,搜索决策树,检查 Java,按大多数星级排序,我会自己查看sanity /quickmlsaebyn/java-decision-tree之一。

于 2017-12-06T23:30:44.437 回答