4

所以我使用 rpart 包创建了一个树模型,我发现了一个有趣的规则,并想知道是否有一种简单的方法可以查看该数据框中的哪些观察结果通过了该规则。

使用 path.rpart 查找它沿树所走的路径似乎非常繁琐,然后手动将这些过滤器输入到数据框中以查找它们。有没有一种方法可以传递树和/或节点以及数据帧并返回该帧中在该节点处结束的所有元素?

4

2 回答 2

9

我修改了代码path.rpart以返回属于特定节点的数据子集,而不是返回有关该节点的信息。它可以通过单击绘图或像path.rpart函数一样传递节点来工作。这是代码

subset.rpart <- function (tree, df, nodes) {
    if (!inherits(tree, "rpart")) 
        stop("Not a legitimate \"rpart\" object")
    stopifnot(nrow(df)==length(tree$where))
    frame <- tree$frame
    n <- row.names(frame)
    node <- as.numeric(n)

    if (missing(nodes)) {
        xy <- rpart:::rpartco(tree)
        i <- identify(xy, n = 1L, plot = FALSE)
        if(i> 0L) {
             return( df[tree$where==i, ] )
        } else {
            return(df[0,])
        }
    }
    else {
        if (length(nodes <- rpart:::node.match(nodes, node)) == 0L) 
            return(df[0,])
        return ( df[tree$where %in% as.numeric(nodes), ] )
    }
}

我将在包中的一些示例数据上使用它

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit)
text(fit)

rpart 树图

然后要找到特定节点的观察结果,运行

subset.rpart(fit, kyphosis)

并单击绘图上的一个节点。完成后,将返回该节点的所有观察结果。您必须使用data.frame用于建模的相同功能才能正常工作。除了单击一个点,您还可以传入一个您发现的节点名称path.rpart

# path.rpart(fit)  
#  node number: 10  ---> looks interesting
#    root
#    Start>=8.5
#    Start< 14.5
#    Age< 55

subset.rpart(fit, kyphosis, 10)
#    Kyphosis Age Number Start
# 14   absent   1      4    12
# 20   absent  27      4     9
# 26   absent   9      5    13
# 37   absent   1      3     9
# 39   absent  20      6     9
# 42   absent  35      3    13
# 57   absent   2      3    13
# 59   absent  51      7     9
# 66   absent  17      4    10
# 69   absent  18      4    11
# 78   absent  26      7    13
# 81   absent  36      4    13
于 2014-06-04T03:45:04.870 回答
1
#' subset of rpart node: return logical index
#' @param tree rpart model
#' @param node which node/leaf?
#' @export
subset_rpart <- function (tree, node) {
  nodes = as.numeric(rownames(tree$frame))
  nodes = log(nodes, 2)
  lower = log(node, 2)
  upper = log(node + 1, 2)
  a = floor(lower)
  lower_ = lower - a
  upper_  = upper - a
  nodes_ = nodes %% 1
  w = which(((nodes_ >= lower_ & nodes_ < upper_) | (nodes_ + 1 < upper_)) & nodes >= lower)
  tree$where %in% w
}



#' subset df by subset_rpart
#' @param tree rpart model
#' @param node node number
#' @param df df
#' @export
subset.rpart = function(tree, node, df){
  df[subset_rpart(tree, node), ]
}
于 2017-08-29T08:56:00.360 回答