我在 R 中使用 Party 包创建了决策树。我试图获得具有最大值的路由/分支。
它可以是来自箱线图的平均值
它可以是来自二叉树的概率值(来源:rdatamining.com)
这实际上可以很容易地完成,尽管您对回归树的最大值的定义很清楚,但对于分类树来说却不是很清楚,因为在每个节点中,不同级别可以有自己的最大值
无论哪种方式,这是一个非常简单的辅助函数,它将为您返回每种树的预测
GetPredicts <- function(ct){
f <- function(ct, i) nodes(ct, i)[[1]]$prediction
Terminals <- unique(where(ct))
Predictions <- sapply(Terminals, f, ct = ct)
if(is.matrix(Predictions)){
colnames(Predictions) <- Terminals
return(Predictions)
} else {
return(setNames(Predictions, Terminals))
}
}
现在幸运的是,您已经从 的示例中获取了您的树?ctree
,因此我们可以对其进行测试(下一次,请提供您自己使用的代码)
回归树(你的第一棵树)
## load the package and create the tree
library(party)
airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq,
controls = ctree_control(maxsurrogate = 3))
plot(airct)
现在,测试功能
res <- GetPredicts(airct)
res
# 5 3 6 9 8
# 18.47917 55.60000 31.14286 48.71429 81.63333
所以我们得到了每个终端节点的预测。您可以从这里轻松进行which.max(res)
(我将留给您决定)
分类树(你的第二棵树)
irisct <- ctree(Species ~ .,data = iris)
plot(irisct, type = "simple")
运行函数
res <- GetPredicts(irisct)
res
# 2 5 6 7
# [1,] 1 0.00000000 0.0 0.00000000
# [2,] 0 0.97826087 0.5 0.02173913
# [3,] 0 0.02173913 0.5 0.97826087
现在,输出有点难以阅读,因为每个类都有自己的概率。您可以使用使其更具可读性
row.names(res) <- levels(iris$Species)
res
# 2 5 6 7
# setosa 1 0.00000000 0.0 0.00000000
# versicolor 0 0.97826087 0.5 0.02173913
# virginica 0 0.02173913 0.5 0.97826087
您可以执行以下操作以获得整体最大值
which(res == max(res), arr.ind = TRUE)
# row col
# setosa 1 1
对于列/行最大值,您可以这样做
matrixStats::colMaxs(res)
# [1] 1.0000000 0.9782609 0.5000000 0.9782609
matrixStats::rowMaxs(res)
# [1] 1.0000000 0.9782609 0.9782609
但是,我再一次让你决定如何从这里开始。