我想以编程方式测试从树生成的一条规则。在树中,根和叶(终端节点)之间的路径可以解释为规则。
在 R 中,我们可以使用rpart
包并执行以下操作:(在这篇文章中,我将使用iris
数据集,仅用于示例目的)
library(rpart)
model <- rpart(Species ~ ., data=iris)
有了这两行,我得到了一个名为 的树model
,它的类是rpart.object
(rpart
文档,第 21 页)。这个对象有很多信息,并且支持多种方法。特别是,该对象有一个frame
变量(可以以标准方式访问:model$frame
)(同上)和方法path.rpath
(rpart
文档,第 7 页),它为您提供从根节点到感兴趣节点的路径(node
参数在功能)
变量的包含树row.names
的frame
节点号。该var
列给出了节点中的分裂变量、yval
拟合值和yval2
类概率等信息。
> model$frame
var n wt dev yval complexity ncompete nsurrogate yval2.1 yval2.2 yval2.3 yval2.4 yval2.5 yval2.6 yval2.7
1 Petal.Length 150 150 100 1 0.50 3 3 1.00000000 50.00000000 50.00000000 50.00000000 0.33333333 0.33333333 0.33333333
2 <leaf> 50 50 0 1 0.01 0 0 1.00000000 50.00000000 0.00000000 0.00000000 1.00000000 0.00000000 0.00000000
3 Petal.Width 100 100 50 2 0.44 3 3 2.00000000 0.00000000 50.00000000 50.00000000 0.00000000 0.50000000 0.50000000
6 <leaf> 54 54 5 2 0.00 0 0 2.00000000 0.00000000 49.00000000 5.00000000 0.00000000 0.90740741 0.09259259
7 <leaf> 46 46 1 3 0.01 0 0 3.00000000 0.00000000 1.00000000 45.00000000 0.00000000 0.02173913 0.97826087
但只有列中标记为<leaf>
终端var
节点(叶子)。在这种情况下,节点是 2、6 和 7。
如上所述,您可以使用path.rpart
提取规则的方法(此技术在 rattle
包和文章Sharma Credit Score中使用,如下所示:
此外,该模型将预测值的值保留在
predicted.levels <- attr(model, "ylevels")
该值与数据集中的列相对yval
应model$frame
。
对于节点号为 7(行号为 5)的叶子,预测值为
> ylevels[model$frame[5, ]$yval]
[1] "virginica"
规则是
> rule <- path.rpart(model, nodes = 7)
node number: 7
root
Petal.Length>=2.45
Petal.Width>=1.75
因此,该规则可以理解为
If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = Virginica
我知道我可以测试(在测试数据集中,我将再次使用 iris 数据集)这条规则有多少真阳性,对新数据集进行子集如下
> hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)
然后计算混淆矩阵
> table(hits$Species, hits$Species == "virginica")
FALSE TRUE
setosa 0 0
versicolor 1 0
virginica 0 45
(注:我使用了与测试相同的 iris 数据集)
我如何以编程方式评估规则?我可以从规则中提取条件如下
> unlist(rule, use.names = FALSE)[-1]
[1] "Petal.Length>=2.45" "Petal.Width>=1.75"
但是,我怎么能从这里继续呢?我无法使用该subset
功能
提前致谢
注意: 为了更清晰,这个问题已经过大量编辑