借助 ncfirth 提供的链接,我修改了其中的代码,使其适用于我的问题:
from sklearn.tree._tree import TREE_LEAF
def is_leaf(inner_tree, index):
    """Tell whether node ``index`` has no children.

    A node counts as a leaf when both its ``children_left`` and
    ``children_right`` entries equal the ``TREE_LEAF`` sentinel.
    """
    left_is_sentinel = inner_tree.children_left[index] == TREE_LEAF
    if not left_is_sentinel:
        # Short-circuit: the right-hand slot is not consulted in this case.
        return left_is_sentinel
    return inner_tree.children_right[index] == TREE_LEAF
def prune_index(inner_tree, decisions, index=0):
    """Collapse redundant splits in the subtree rooted at ``index``, in place.

    A split is redundant when both of its children are leaves and the node
    and both children carry equal entries in ``decisions``. Such a node is
    turned into a leaf by overwriting its two child links with ``TREE_LEAF``.

    Do not use this directly - use ``prune_duplicate_leaves`` instead.

    Args:
        inner_tree: Object exposing index-assignable ``children_left`` and
            ``children_right`` sequences (presumably a fitted estimator's
            ``tree_`` - confirm against the caller).
        decisions: Sequence indexable by node id; entries are compared
            with ``==``.
        index: Node id of the subtree root. Defaults to 0, which is where
            ``prune_duplicate_leaves`` starts.

    Returns:
        None. Only ``children_left`` / ``children_right`` are modified.

    Note:
        The traversal is recursive, so the call depth grows with the depth
        of the subtree being pruned.
    """
    # A leaf has nothing below it to prune. Without this guard, the TREE_LEAF
    # sentinel stored in its child slots would be used as a node id below.
    if is_leaf(inner_tree, index):
        return

    left = inner_tree.children_left[index]
    right = inner_tree.children_right[index]

    # Start pruning from the bottom - if we started from the top, we might
    # miss nodes that only become leaves during pruning.
    if not is_leaf(inner_tree, left):
        prune_index(inner_tree, decisions, left)
    if not is_leaf(inner_tree, right):
        prune_index(inner_tree, decisions, right)

    # The recursive calls may have just turned a child into a leaf, so the
    # leaf status of both children is checked again here.
    if (
        is_leaf(inner_tree, left)
        and is_leaf(inner_tree, right)
        and decisions[index] == decisions[left]
        and decisions[index] == decisions[right]
    ):
        # Turn the node into a leaf by "unlinking" its children.
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
def prune_duplicate_leaves(mdl):
    """Remove leaf pairs that make the same decision as their parent.

    For every node, a decision is derived from ``mdl.tree_.value`` by taking
    the argmax over its last axis. ``prune_index`` then walks the tree from
    node 0 and turns a split into a leaf when both of its children are
    leaves and all three nodes share the same decision.

    Args:
        mdl: Object with a ``tree_`` attribute whose ``value`` is an array
            with at least three axes (assumed layout: node, output, class -
            confirm against the scikit-learn documentation).

    Returns:
        None. ``mdl.tree_`` is modified in place; only its child links are
        rewritten, every other per-node array is left untouched.

    Note:
        NOTE(review): when the last axis of ``value`` has length 1
        (presumably the case for regression trees), every decision is 0 and
        the whole tree collapses into its root - confirm the model is a
        classifier before calling.
    """
    # Keep one row per node. Flattening the (node, output) argmax result
    # would make list positions stop matching node ids as soon as the second
    # axis is longer than 1; rows compare with == just like scalars do.
    decisions = mdl.tree_.value.argmax(axis=2).tolist()
    prune_index(mdl.tree_, decisions)
在 DecisionTreeClassifier clf 上使用它:
# Prunes clf.tree_ in place; clf is assumed to be an already-fitted classifier.
prune_duplicate_leaves(clf)
编辑:修复了代码在较复杂的树上出现的错误。