
Apologies in advance if I phrase this question poorly; I'm new to R and to statistical analysis in general.

I've generated a conditional inference tree using the party library.
When I call plot(my_tree, type = "simple") I get a result like this:

(Image: R tree plot)

When I call print(my_tree) I get a result like this:

1) SOME_VALUE <= 2.5; criterion = 1, statistic = 1306.478
  2) SOME_VALUE <= -10.5; criterion = 1, statistic = 173.416
    3) SOME_VALUE <= -16; criterion = 1, statistic = 19.385
      4)*  weights = 275 
    3) SOME_VALUE > -16
      5)*  weights = 261 
  2) SOME_VALUE > -10.5
    6) SOME_VALUE <= -2.5; criterion = 1, statistic = 24.094
      7) SOME_VALUE <= -6.5; criterion = 0.974, statistic = 4.989
        8)*  weights = 346 
      7) SOME_VALUE > -6.5
        9)*  weights = 563 
    6) SOME_VALUE > -2.5
      10)*  weights = 442 
1) SOME_VALUE > 2.5
  11) SOME_VALUE <= 10; criterion = 1, statistic = 225.148
    12) SOME_VALUE <= 6.5; criterion = 1, statistic = 18.789
      13)*  weights = 648 
    12) SOME_VALUE > 6.5
      14)*  weights = 473 
  11) SOME_VALUE > 10
    15) SOME_VALUE <= 16; criterion = 1, statistic = 51.729
      16)*  weights = 595 
    15) SOME_VALUE > 16
      17) SOME_VALUE <= 23.5; criterion = 0.997, statistic = 8.931
        18)*  weights = 488 
      17) SOME_VALUE > 23.5
        19)*  weights = 365 

I prefer the output of print, but it seems to be missing the y = (0.96, 0.04) values that the plot shows.

Ideally, I'd like my output to look something like this:

1) SOME_VALUE <= 2.5; criterion = 1, statistic = 1306.478
  2) SOME_VALUE <= -10.5; criterion = 1, statistic = 173.416
    3) SOME_VALUE <= -16; criterion = 1, statistic = 19.385
      4)*  weights = 275; y = (0.96, 0.04)
    3) SOME_VALUE > -16
      5)*  weights = 261; y = (0.831, 0.169)
  2) SOME_VALUE > -10.5
...

How can I do this?


1 Answer


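You can do this with the partykit package (the successor to party), but even there it takes some hacking. In principle, the print() function can be customized through panel functions for the inner nodes, the terminal nodes, and so on. But even for a seemingly simple task like this one, they are not very elegant.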

Since you appear to be using a tree with a bivariate response, let's consider this simple (albeit not very meaningful) reproducible example:

library("partykit")
airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone + Wind ~ ., data = airq)

For the inner nodes, suppose we only want to show the p-value, which is readily available in the $info of each node. We can format it via:

ip <- function(node) formatinfo_node(node,
  prefix = " ",
  FUN = function(info) paste0("[p = ", format.pval(info$p.value), "]")
)
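To see what the formatter passed as FUN produces on its own, here is a minimal base R sketch. The `info` list is a hand-made stand-in for a node's $info (the p-value is copied from the printed tree in this answer), and the names `fmt_p` and `fmt_p3` are made up for illustration; they are not part of partykit.

```r
# Hypothetical stand-in for the $info list of one inner node.
# Only the p.value element is used by the formatter.
info <- list(p.value = 0.0044842)

# Same formatter as the FUN argument of ip
fmt_p <- function(info) paste0("[p = ", format.pval(info$p.value), "]")
print(fmt_p(info))

# Variant (an assumption, not from the answer): pass digits to
# format.pval() to get a shorter label
fmt_p3 <- function(info) {
  paste0("[p = ", format.pval(info$p.value, digits = 3), "]")
}
print(fmt_p3(info))
```

Any function that takes the info list and returns a character string should work the same way as FUN, so this is a convenient place to try out label formats before printing the whole tree.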

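For the terminal nodes, we want to show the number of observations (assuming no weights were used) and the mean response. Both are precomputed in small tables and then looked up via the $id of each node: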

n <- table(ct$fitted[["(fitted)"]])
m <- aggregate(ct$fitted[["(response)"]], list(ct$fitted[["(fitted)"]]), mean)
m <- apply(m[, -1], 1, function(x) paste(round(x, digits = 3), collapse = ", "))
names(m) <- names(n)
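The y = (0.96, 0.04) values in the question may instead come from a two-level factor response, in which case class proportions rather than means would be the natural summary. The following base R sketch shows the same lookup-table idea under that assumption. The `fitted` data frame is a made-up stand-in for ct$fitted, with hypothetical column names and values, so the real column names "(fitted)" and "(response)" would need to be substituted.

```r
# Hypothetical stand-in for ct$fitted: one row per observation, holding the
# terminal node id and a two-level factor response (values are made up).
fitted <- data.frame(
  node = c(4, 4, 4, 4, 5, 5, 5, 5, 5),
  response = factor(
    c("no", "no", "no", "yes", "no", "yes", "yes", "no", "no"),
    levels = c("no", "yes")
  )
)

# Number of observations per terminal node
n <- table(fitted$node)

# Class proportions: one row per node, one column per class, rows sum to 1
p <- prop.table(table(fitted$node, fitted$response), margin = 1)

# Collapse each row into a comma-separated string, named by node id
m <- apply(p, 1, function(x) paste(round(x, digits = 3), collapse = ", "))

# Labels in the style requested in the question
labels <- paste0("n = ", n[names(m)], "; y = (", m, ")")
names(labels) <- names(m)
print(labels)
```

With `n` and `m` built this way, a terminal panel function can look up both by as.character(node$id), exactly as with the means.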

The panel function is then defined as:

tp <- function(node) formatinfo_node(node,
  prefix = ": ",
  FUN = function(info) paste0(
    "n = ", n[as.character(node$id)],
    ", y = (", m[as.character(node$id)], ")"
  )
)

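To apply this in the print() method, we need to call print.party() directly, because print.constparty() currently does not pass the panel functions on correctly. (We'll have to fix this in the partykit package.)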

print.party(ct, inner_panel = ip, terminal_panel = tp)
## [1] root
## |   [2] Temp <= 82 [p = 0.0044842]
## |   |   [3] Temp <= 77: n = 52, y = (18.615, 11.562)
## |   |   [4] Temp > 77: n = 27, y = (41.815, 9.737)
## |   [5] Temp > 82: n = 37, y = (75.405, 7.565)

Hopefully this is close to what you want to do and gives you a template for further modifications.

Answered 2015-10-27T01:24:02.590