扩展先前的问题: 更改使用导出 graphviz 创建的决策树图的颜色
我将如何根据主要类(虹膜种类)为树的节点着色,而不是二元区分?这应该需要 iris.target_names(描述类的字符串)和 iris.target(类)的组合。
import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections
clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
edges = graph.get_edge_list()
colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)
for edge in graph.get_edge_list():
edges[edge.get_source()].append(int(edge.get_destination()))
for edge in edges:
edges[edge].sort()
for i in range(2):
dest = graph.get_node(str(edges[edge][i]))[0]
dest.set_fillcolor(colors[i])
graph.write_png('tree.png')