2

这个问题类似于(并且可能是一个简单的扩展)这里链接的问题:

如何将sklearn决策树规则提取为pandas布尔条件?

来自上述链接的解决方案综合如下:

首先,让我们使用关于决策树结构的 scikit 文档来获取有关所构建树的信息:

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold

然后我们定义两个递归函数。第一个将找到从树根开始的路径以创建特定节点(在我们的例子中是所有叶子)。第二个将编写用于使用其创建路径创建节点的特定规则:

def find_path(node_numb, path, x):
        path.append(node_numb)
        if node_numb == x:
            return True
        left = False
        right = False
        if (children_left[node_numb] !=-1):
            left = find_path(children_left[node_numb], path, x)
        if (children_right[node_numb] !=-1):
            right = find_path(children_right[node_numb], path, x)
        if left or right :
            return True
        path.remove(node_numb)
        return False


def get_rule(path, column_names):
    mask = ''
    for index, node in enumerate(path):
        #We check if we are not in the leaf
        if index!=len(path)-1:
            # Do we go under or over the threshold ?
            if (children_left[node] == path[index+1]):
                mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node])
            else:
                mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node])
    # We insert the & at the right places
    mask = mask.replace("\t", "&", mask.count("\t") - 1)
    mask = mask.replace("\t", "")
    return mask

最后,我们使用这两个函数首先存储每个叶子的创建路径。然后存储用于创建每个叶子的规则:

Leaves leave_id = clf.apply(X_test)

paths ={} for leaf in np.unique(leave_id):
    path_leaf = []
    find_path(0, path_leaf, leaf)
    paths[leaf] = np.unique(np.sort(path_leaf))

rules = {} for key in paths:
    rules[key] = get_rule(paths[key], pima.columns)

使用您提供的数据,输出为:

rules = {3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727)  ",  
4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469`727)",  
6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5)  ",  
7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453 & (df['skin']> 27.5)  ",  
10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) &(df['insulin']<= 145.5)  ",  
11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5)  ",  
13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5)  ",  
14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5)  "}

由于规则是字符串,你不能直接使用 df[rules[3]] 调用它们,你必须像这样使用 eval 函数 df[eval(rules[3])]

上面发布的解决方案非常适合查找每个终止节点的路径。我想知道是否可以以与上述链接(字典/列表格式)完全相同的输出格式存储每个节点(叶子和终止节点)的路径。

谢谢!

4

1 回答 1

0

好的,所以我想出了一个解决我的问题的方法(尽管我不相信它是最好/最有效的方法),它也不是我问题的直接答案(我没有为每个人存储路径节点 - 只需创建一个能够解析存储信息的函数)。它是上述解决方案的第二部分,允许您为要查找的特定节点提取子集数据。

node_id = 3

def datatree_path_summarystats(node_id):
    for k, v in paths.items():
        if node_id in v:
            d = k,v

    ruleskey = d[0]
    numberofsteps = sum(map(lambda x : x<node_id, d[1]))

    for k, v in rules.items():
        if k == ruleskey:
            b = k,v

    stringsubset = b[1]

    datasubset = "&".join(stringsubset.split('&')[:numberofsteps])
    return datasubset

datasubset = datatree_path_summarystats(node_id)

df[eval(datasubset)]

此函数遍历包含您要查找的节点 ID 的路径。然后,它将根据节点数量拆分规则,从而创建基于该特定节点对数据帧进行子集化的逻辑。

于 2019-12-10T23:29:51.847 回答