1

dict()下面是我在 Python 中训练的决策树 ( ) 的简化示例:

tree= {'Age': {'> 55': 0.4, '< 18': {'Income': {'high': 0, 'low': 0.2}}, 
               '18-35': 0.25, '36-55': {'Marital_Status': {'single': {'Income': 
               {'high': 0, 'low': 0.1}}, 'married': 0.05}}}}

叶节点(框)中的数字表示类标签(例如,TRUE)出现在该节点中的概率。从视觉上看,树看起来像这样:

在此处输入图像描述

我正在尝试编写一个通用的后修剪算法,该算法将值小于0.3其父节点的节点合并。因此,具有0.3阈值的生成树在绘制时将如下所示:

在此处输入图像描述

在第二个图中,请注意Income节点 atAge<18现在已合并到根节点Age。并且 由于其所有叶节点(在多个级别)的总和小于 0.3 ,Age=36-55, Marital_Staus因此已合并到。Age

这是我想出的不完整的伪代码(到目前为止):

def post_prune  (dictionary, threshold):

    for k in dictionary.keys():

        if isinstance(dictionary[k], dict): # interim node

            post_prune(dictionary[k], threshold)

        else: # leaf node

            if dictionary[k]> threshold:
                pass
            else:
                to_do = 'delete this node'

想发布这个问题,因为我觉得这应该已经解决了很多次。

谢谢你。

PS:我不打算将最终结果用于分类,因此以这种方式(从外观上)修剪是可行的。

4

1 回答 1

1

你可以尝试这样的事情:

def simplify(tree, threshold):
    # simplify tree bottom-up
    for key, child in tree.items():
        if isinstance(child, dict):
            tree[key] = simplify(child, threshold)
    # all child-nodes are leafs and smaller than threshold -> return max
    if all(isinstance(child, str) and float(child) <= threshold 
           for child in tree.values()):
        return max(tree.values(), key=float)
    # else return tree itself
    return tree

例子:

>>> tree= {'Age': {'> 55': '0.4', '18-35': '0', \
                   '< 18': {'Income': {'high': '0', 'low': '0.2'}}, \
                   '36-55': {'Marital_Status': {'single': {'Income': {'high': '0', 'low': '0.1'}}, \
                                                'married': '0.3'}}}}
>>> simplify(tree, 0.2)
{'Age': {'> 55': '0.4', '< 18': '0.2', '18-35': '0', 
         '36-55': {'Marital_Status': {'single': '0.1', 'married': '0.3'}}}}

更新:好像我误解了你的问题:如果它们的总和小于阈值,你希望简化的树保持叶子的总和!您建议的编辑略有偏差。试试这个:

def simplify(tree, threshold):
    # simplify tree bottom-up
    for key, child in tree.items():
        if isinstance(child, dict):
            tree[key] = simplify(child, threshold)
    # all child-nodes are leafs and sum smaller than threshold -> return sum
    if all(isinstance(child, str) for child in tree.values()) \
           and sum(map(float, tree.values())) <= threshold:
        return str(sum(map(float, tree.values())))
    # else return tree itself
    return tree
于 2014-07-11T11:15:25.957 回答