1

我是 R 新手,我正在尝试使用party:ctree库的分类决策树。一切似乎都很好。我得到了预期的结果和一个很好的描述情节。

现在,如果我想从拟合摘要中提取结果,我必须遍历每个节点并提取信息。幸运的是,这已经由@baydoganm here编写了。我想扩展此代码并将结果写入 adataframe而不是打印它。

可重现的代码:

library(party)
 ct <- ctree(Species ~ ., data = iris)

   traverse <- function(treenode){
        if(treenode$terminal){
           bas=paste(treenode$nodeID,treenode$prediction)
         print(bas) #here the results are printed
         return(0)
                } 

 traverse(treenode$left)
 traverse(treenode$right)
  }

 traverse(ct@tree) #function call

这工作正常,我在控制台上得到输出。现在,如果我想将结果写入数据框,我就会遇到问题。

到目前为止我尝试了什么:尝试使用可变闭包()写入列表。但不知道如何让它工作。

l <- list()
count = 0
traverse1 <- function(treenode,l){

if((treenode$terminal == T)){
    count <<- count + 1
    print(count)
    node = c(treenode$nodeID)
    pred = c(treenode$prediction)
    l[[count]] <- data.frame(node,pred) #write results in the dataframe    
  } 

  traverse1(treenode$left,l)
  traverse1(treenode$right,l)

}
test <- traverse1(ct@tree,l)# function call

我只得到最后一次调用函数的结果,其余为空

4

2 回答 2

2

如果您使用包中的新改进ctree()实现partykit,那么它的fitted组件中包含您需要的所有信息:

library("partykit")
ct <- ctree(Species ~ ., data = iris)
head(fitted(ct))
##   (fitted) (weights) (response)
## 1        2         1     setosa
## 2        2         1     setosa
## 3        2         1     setosa
## 4        2         1     setosa
## 5        2         1     setosa
## 6        2         1     setosa

xtabs()因此,对于分类树,您可以使用(或table())轻松构建响应的绝对频率表。对于回归树,tapply()可以很容易地用于获取均值、中位数等。

在这种情况下,让我们以表格形式查看绝对频率和相对频率:

tab <- xtabs(~ `(fitted)` + `(response)`, data = fitted(ct))
tab
##         (response)
## (fitted) setosa versicolor virginica
##        2     50          0         0
##        5      0         45         1
##        6      0          4         4
##        7      0          1        45
ptab <- prop.table(tab, 1)
ptab
##         (response)
## (fitted)     setosa versicolor  virginica
##        2 1.00000000 0.00000000 0.00000000
##        5 0.00000000 0.97826087 0.02173913
##        6 0.00000000 0.50000000 0.50000000
##        7 0.00000000 0.02173913 0.97826087

获取频率表的另一种途径是tabtable(predict(ct, type = "node"), iris$Species)

如果您想将其中任何一个转换为数据框,则as.data.frame()工作正常(可能加上一些变量的重新标记......):

as.data.frame(ptab)
##    X.fitted. X.response.       Freq
## 1          2      setosa 1.00000000
## 2          5      setosa 0.00000000
## 3          6      setosa 0.00000000
## 4          7      setosa 0.00000000
## 5          2  versicolor 0.00000000
## 6          5  versicolor 0.97826087
## 7          6  versicolor 0.50000000
## 8          7  versicolor 0.02173913
## 9          2   virginica 0.00000000
## 10         5   virginica 0.02173913
## 11         6   virginica 0.50000000
## 12         7   virginica 0.97826087
于 2016-01-25T22:52:59.730 回答
2

聪明的方式:assign()在全局环境中使用写:

require(party) 
ct <- ctree(Species ~ ., data = iris)

tt <- NULL

traverse <- function(treenode){
  if(treenode$terminal){
    bas=paste(treenode$nodeID,treenode$prediction)
    assign("tt", c(tt, bas), envir = .GlobalEnv)
    print(bas) #here the results are printed
    return(0)
  } 

  traverse(treenode$left)
  traverse(treenode$right)
}

traverse(ct@tree) #function call

data.frame(node.id = unlist(lapply(str_split(tt, " "), function(x) x[[1]]))
       , prediction = unlist(lapply(str_split(tt, " "), function(x) x[[2]])))

肮脏的方式:sink()用来保存你的打印输出。

sink(file = "test.csv", append = T)
traverse(ct@tree) #function call
sink()

tt <- read.csv("test.csv", header = F)
于 2016-01-25T16:46:20.563 回答