
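I'm building a tree with the partykit R package, and I'd like to know whether there is a simple, efficient way to determine the depth of each internal node. For example, the root node has depth 0, its two children have depth 1, their children have depth 2, and so on. This will eventually be used to compute the minimal depth of each variable. Below is a very basic example (taken from vignette("constparty", package="partykit")):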

library("partykit")
library("rpart")
data("Titanic", package = "datasets")
ttnc <- as.data.frame(Titanic)
ttnc <- ttnc[rep(1:nrow(ttnc), ttnc$Freq), 1:4]
names(ttnc)[2] <- "Gender"
rp <- rpart(Survived ~ ., data = ttnc)
ttncTree <- as.party(rp)
plot(ttncTree)

#This is one of my many attempts which does NOT work
internalNodes<-nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)]
depth(ttncTree)-unlist(nodeapply(ttncTree, ids=internalNodes, FUN=function(n){depth(n)}))

In this example, I'd like output similar to:

nodeid = 1 2 4 7 
depth  = 0 1 2 1
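Comparing the attempt above with this target shows why it cannot work, assuming partykit's usual behaviour for depth(): applied to a node, it measures how far the subtree below that node extends downward, not how far the node sits from the root. Subtracting it from the height of the whole tree therefore recovers a node's depth only for nodes on a longest root-to-leaf path. Node 7 has only terminal children, just like node 4, so the attempt gives both the same value even though node 7 is one level higher. A quick check, rebuilding the tree so the snippet runs on its own:

```r
library("partykit")
library("rpart")

# Rebuild the example tree from the question
data("Titanic", package = "datasets")
ttnc <- as.data.frame(Titanic)
ttnc <- ttnc[rep(1:nrow(ttnc), ttnc$Freq), 1:4]
names(ttnc)[2] <- "Gender"
ttncTree <- as.party(rpart(Survived ~ ., data = ttnc))

internalNodes <- nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)]

# depth() of a node measures how far its own subtree extends downward
heights <- unlist(nodeapply(ttncTree, ids = internalNodes,
                            FUN = function(n) depth(n)))

# The attempted formula: height of the whole tree minus subtree height
attempt <- depth(ttncTree) - heights
data.frame(nodeid = internalNodes, subtreeHeight = heights, attempt = attempt)
```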

Apologies if my question is too specific.


2 Answers


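Here is one possible solution. It should be efficient enough, since trees usually don't have more than a few dozen nodes. I've left out node #1: its depth is always 0, so (IMO) there is no need to compute or display it.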

Inters <- nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)][-1]
table(unlist(sapply(Inters, function(x) intersect(Inters, nodeids(ttncTree, from = x)))))
# 2 4 7 
# 1 2 1 
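To see why counting works, it helps to look at the list that table() summarises. For each non-root internal node x, the anonymous function collects the internal nodes in the subtree rooted at x (x included). A node is therefore listed once for itself and once for each non-root internal ancestor, which is exactly its depth. The sketch below rebuilds the question's tree and uses lapply() instead of sapply() so that the intermediate result is always a list, whatever the subtree sizes:

```r
library("partykit")
library("rpart")

# Rebuild the example tree from the question
data("Titanic", package = "datasets")
ttnc <- as.data.frame(Titanic)
ttnc <- ttnc[rep(1:nrow(ttnc), ttnc$Freq), 1:4]
names(ttnc)[2] <- "Gender"
ttncTree <- as.party(rpart(Survived ~ ., data = ttnc))

# Non-root internal node ids, as in the answer
Inters <- nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)][-1]

# For each such node x: the internal nodes in the subtree rooted at x
subtreeInternals <- lapply(Inters, function(x)
  intersect(Inters, nodeids(ttncTree, from = x)))
names(subtreeInternals) <- Inters
subtreeInternals

# Each node appears once per non-root internal ancestor-or-self,
# so the counts are the node depths
depthCounts <- table(unlist(subtreeInternals))
depthCounts
```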
answered 2016-02-08T10:09:13.633

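I recently had to revisit this problem. Below is a function that determines the depth of every node. I compute each node's depth as the number of vertical bars (|) that appear on that node's line in the output of the print.party() function.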

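library(stringr)
# Depth of every node, read off the printed tree: each level of nesting
# in the print.party() output adds one "|" in front of the node's "[id]" label
idDepth <- function(tree) {
  outTree <- capture.output(tree)
  idCount <- 1
  depthValues <- rep(NA, length(tree))  # length(tree) = number of nodes
  names(depthValues) <- seq_len(length(tree))
  for (index in seq_along(outTree)) {
    # Node lines contain the node id in brackets, e.g. "[4]"; nodes are
    # printed in id order, so the k-th such line belongs to node k
    if (grepl("\\[[0-9]+\\]", outTree[index])) {
      depthValues[idCount] <- str_count(outTree[index], "\\|")
      idCount <- idCount + 1
    }
  }
  return(depthValues)
}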

> idDepth(ttncTree)
1 2 3 4 5 6 7 8 9 
0 1 2 2 3 3 1 2 2
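If you would rather not depend on stringr, the counting step can be done in base R with gregexpr(). This is a minimal sketch; countBars() is a hypothetical helper name, and the input lines are illustrative strings in the style of print.party() output, not captured output:

```r
# Count the "|" characters on each line using only base R
# (a stand-in for stringr::str_count(lines, "\\|"))
countBars <- function(lines) {
  hits <- gregexpr("|", lines, fixed = TRUE)
  # gregexpr() returns -1 for a line with no match, so count positive positions
  vapply(hits, function(h) sum(h > 0), integer(1))
}

# Illustrative node lines: nesting shows up as leading "|" markers
lines <- c("[1] root",
           "|   [2] Gender in Male",
           "|   |   [3] Age in Adult")
countBars(lines)
```

Inside idDepth(), str_count(outTree[index], "\\|") could then be replaced by countBars(outTree[index]).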

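Surely there is a simpler and faster solution, but this is already faster than the intersect() approach. Below is a timing comparison on a large tree (about 1,500 nodes):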

# Compare computation time for large tree #
library(mlbench)
set.seed(470174)
dat <- data.frame(mlbench.friedman1(5000))
rp <- rpart(as.formula(paste0("y ~ ", paste(paste0("x.", 1:10), collapse=" + "))),
            data=dat, control = rpart.control(cp = -1, minsplit=3, maxdepth = 10))
partyTree <- as.party(rp)

> length(partyTree) # Number of nodes
[1] 1503
> 
> # Intersect() computation time
> Inters <- nodeids(partyTree)[-nodeids(partyTree, terminal = TRUE)][-1]
> system.time(table(unlist(sapply(Inters, function(x) intersect(Inters, nodeids(partyTree, from = x))))))
   user  system elapsed 
  22.38    0.00   22.44 
> 
> # Proposed computation time
> system.time(idDepth(partyTree))
   user  system elapsed 
   2.38    0.00    2.38
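As for a simpler route: one option, sketched here and not benchmarked, is to skip the printed output and walk the tree structure directly with partykit's node accessors node_party(), kids_node() and id_node(). Each node is visited once and receives its parent's depth plus one. The sketch assumes node ids run from 1 to length(tree), as they do for trees created by as.party(). The last lines turn the depths into a minimal depth per split variable (the question's end goal), assuming that varid_split() indexes the columns of tree$data; nodeDepths(), depthAll and minDepth are illustrative names:

```r
library("partykit")
library("rpart")

# Example tree from the question (constparty vignette)
data("Titanic", package = "datasets")
ttnc <- as.data.frame(Titanic)
ttnc <- ttnc[rep(1:nrow(ttnc), ttnc$Freq), 1:4]
names(ttnc)[2] <- "Gender"
ttncTree <- as.party(rpart(Survived ~ ., data = ttnc))

# Depth of every node, found by a recursive walk from the root.
# Assumes node ids are 1..length(tree), as for trees from as.party().
nodeDepths <- function(tree) {
  depths <- integer(length(tree))  # one slot per node, filled during the walk
  names(depths) <- nodeids(tree)
  walk <- function(node, d) {
    depths[id_node(node)] <<- d
    # kids_node() is NULL for a terminal node, so the recursion stops there
    for (kid in kids_node(node)) walk(kid, d + 1L)
  }
  walk(node_party(tree), 0L)
  depths
}

depthAll <- nodeDepths(ttncTree)
internal <- nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)]
depthAll[internal]  # depths of the internal nodes only

# Minimal depth per split variable (assumes varid_split() indexes the
# columns of tree$data, which holds the model frame for as.party() trees)
splitVar <- unlist(nodeapply(ttncTree, ids = internal, FUN = function(n)
  names(ttncTree$data)[varid_split(split_node(n))]))
minDepth <- tapply(depthAll[internal], splitVar, min)
minDepth
```

The recursion goes only as deep as the tree itself (at most maxdepth levels for rpart trees), so stack depth should not be an issue.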
answered 2019-05-21T22:00:29.807