我正在尝试使用通用数据集重现我的数据集出现的错误。如果我遗漏了什么,请纠正我。
使用拟合Classification
树后library(party)
,我试图在每个节点上获取树的拆分条件。我设法编写了一个我认为运行良好的代码,直到我发现了一个错误。谁能帮我解决它?
我的代码:
require(party)
iris$Petal.Width <- as.factor(iris$Petal.Width)#imp to convert to factorial
(ct <- ctree(Species ~ ., data = iris))
plot(ct)
#print(ct)
a <- ct #convert it to s4 object
t <- a@tree
#recursive function to traverse the tree and get the splitting conditions
recurse_tree <- function(tree,ret_list=list(),sub_list=list()){
if(!tree$terminal){
sub_list$assign <-list(tree$psplit$splitpoint,tree$psplit$variableName,class(tree$psplit))
names(sub_list)[which(names(sub_list)=="assign")] <- paste("node",tree$nodeID,sep="")
ret_list <- recurse_tree(tree$left, ret_list, sub_list)
ret_list <- recurse_tree(tree$right, ret_list, sub_list)
}
if(tree$terminal){
ret_list$assign <- c(sub_list, tree$prediction)
names(ret_list)[which(names(ret_list)=="assign")] <- paste("node",tree$nodeID,sep="")
return(ret_list)
}
return(ret_list)
}
result <- recurse_tree(t) #call to the functions
现在,结果给了我所有节点的列表以及拆分条件和预测(我假设)。但是,当我检查拆分条件时
- Node5上的预期输出:{1.1, 1.2, 1.6, 1.7 }
# from printing the tree print(ct), I get this
我从我的函数得到的Node5 的输出: {"1" , "1.3" ,"1.4" ,"1.5" } 这基本上是 Node6 的拆分条件,这是错误的。我是怎么得到这个的?
z <- result[2] #I know node5 is second in the list
z <- unlist(z,recursive = F,use.names = T) #unlist levels(z[[3]][[1]]) [which((z[[3]][[1]])==0)] #to find levels of corresponding values
我怀疑的是,我的函数(recurse_tree
)总是给我右终端节点而不是左节点的分裂条件。任何帮助将不胜感激。