这真的很好,我不知道这是一个选择。
所有的工作似乎都是获取每个节点上使用的原始数据的子集。这对于终端节点来说很容易,但我没有找到一种直接的方法来识别每个节点中使用的数据行,而不仅仅是叶子。如果有人知道更简单的方法,我很想听听。
library('rpart.plot')
set.seed(1)
mydb<-data.frame(results=rnorm(1000,0,1),expo=runif(1000),var1=sample(LETTERS[1:4],1000,replace=T),
var2=sample(LETTERS[5:6],1000,replace=T),var3=sample(LETTERS[20:25], 1000,replace=T))
mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0)
pfit<- prune(mytree, cp=mytree$cptable[4,"CP"])
rpart.plot(pfit)
定义你的新函数,它接受x
拟合的结果rpart
(我没有研究其他参数,但小插图应该会有所帮助)。
对于每一行,x$frame
我们都需要获取用于计算汇总统计的数据。不幸的是,x$where
它只告诉我们每个观察点所在的终端节点。所以对于每个节点号,我们subset.rpart
用来获取底层数据,并用它做任何你想做的事情
f <- function(x, labs, digits, varlen) {
nodes <- as.integer(rownames(x$frame))
z <- sapply(nodes, function(y) {
data <- subset.rpart(x, y)
c(mean = mean(data$expo), nrow(data), nrow(data) / length(x$where) * 100)
})
sprintf('Mean expo: %.2f\nn=%.0f (%.0f%%)', z[1, ], z[2, ], z[3, ])
}
prp(pfit, type=1, extra=100, fallen.leaves=FALSE,
shadow.col="darkgray", box.col=rgb(0.8,0.9,0.8),
node.fun = f)
这项工作是通过subset.rpart
它接受一个节点号并返回节点上data
使用的子集来完成的。
subset.rpart <- function(tree, node = 1L) {
## returns subset of tree$call$data used on any node
data <- eval(tree$call$data, parent.frame(1L))
wh <- sapply(as.integer(rownames(tree$frame)), parent)
wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]
}
parent <- function(x) {
## returns vector of parent nodes
if (x[1] != 1)
c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
}
测试
## tests
dim(subset.rpart(pfit, 1)) == dim(mydb)
# [1] TRUE TRUE
## terminal nodes
nodes <- as.integer(rownames(pfit$frame[pfit$frame$var %in% '<leaf>', ]))
sum(sapply(nodes, function(x) nrow(subset.rpart(pfit, x)))) == nrow(mydb)
# [1] TRUE