0

以 iris 数据集为例:

# Load the `party` package, which provides ctree().
library("party")
# Fix the random seed so the example is reproducible.
set.seed(1)
# Fit a conditional inference tree with Species as the response and all
# remaining columns of `iris` as inputs.
x <- ctree(Species ~ ., data=iris)

# Show the default text representation of the fitted tree.
print(x)

 #    Conditional inference tree with 4 terminal nodes
 # 
 # Response:  Species 
 # Inputs:  Sepal.Length, Sepal.Width, Petal.Length, Petal.Width 
 # Number of observations:  150 
 # 
 # 1) Petal.Length <= 1.9; criterion = 1, statistic = 140.264
 #   2)*  weights = 50 
 # 1) Petal.Length > 1.9
 #   3) Petal.Width <= 1.7; criterion = 1, statistic = 67.894
 #     4) Petal.Length <= 4.8; criterion = 0.999, statistic = 13.865
 #       5)*  weights = 46 
 #     4) Petal.Length > 4.8
 #       6)*  weights = 8 
 #   3) Petal.Width > 1.7
 #     7)*  weights = 46 

 # Draw the tree using the "simple" plot type.
 plot(x, type="simple")

ctree 图

对于上图中的树,我希望打印出如下结果(只需要终端节点的结果即可):

"Petal.Length <= 1.9"; criterion = 1, statistic = 140.264, weights = 50 
"Petal.Length > 1.9, Petal.Width <= 1.7, Petal.Length <= 4.8", criterion = 0.999, statistic = 13.865, weights = 46
"Petal.Length > 1.9, Petal.Width <= 1.7, Petal.Length > 4.8", criterion = 0.999, statistic = 13.865, weights = 8
"Petal.Length > 1.9, Petal.Width > 1.7", criterion = 0.999, statistic = 13.865, weights = 46
4

1 回答 1

5

由于您没有提供可重现的示例,我在这里使用 iris 数据集进行分类:

# Load the `party` package, which provides ctree().
library("party")
# Fix the random seed so the example is reproducible.
set.seed(1)
# Fit a conditional inference tree with Species as the response and all
# remaining columns of `iris` as inputs.
x <- ctree(Species ~ ., data=iris)

# Show the default text representation of the fitted tree.
print(x)
#    Conditional inference tree with 4 terminal nodes
# 
# Response:  Species 
# Inputs:  Sepal.Length, Sepal.Width, Petal.Length, Petal.Width 
# Number of observations:  150 
# 
# 1) Petal.Length <= 1.9; criterion = 1, statistic = 140.264
#   2)*  weights = 50 
# 1) Petal.Length > 1.9
#   3) Petal.Width <= 1.7; criterion = 1, statistic = 67.894
#     4) Petal.Length <= 4.8; criterion = 0.999, statistic = 13.865
#       5)*  weights = 46 
#     4) Petal.Length > 4.8
#       6)*  weights = 8 
#   3) Petal.Width > 1.7
#     7)*  weights = 46 

# Draw the tree using the "simple" plot type.
plot(x, type="simple")

ctree 图

下一步,我只是重写(覆盖)相关的 print 方法来格式化输出;输出中没有 prediction 列,因为我不清楚您具体指的是什么:

# Print method for inner (splitting) nodes of a tree.
#
# Emits one line of the form
#   "<nodeID>, <left child ID>, <right child ID>; weight = <sum of weights>"
# for this node and then prints the left and right subtrees, so the whole
# tree is listed in pre-order.
#
# x:   a node list with elements `nodeID`, `weights`, `left` and `right`.
# ...: accepted for compatibility with the print() generic; unused.
#
# Returns `x` invisibly, as is conventional for print methods.
print.SplittingNode <- function(x, ...) {
    # Spell out `weights` in full: `x$weight` only works through `$` partial
    # matching, which silently yields NULL (and therefore a weight of 0) as
    # soon as the node has a second element whose name starts with "weight".
    cat(sprintf("%d, %d, %d; weight = %d\n",
                x$nodeID, x$left$nodeID, x$right$nodeID, sum(x$weights)))
    print(x$left)
    print(x$right)
    invisible(x)
}
# Print method for terminal (leaf) nodes of a tree.
#
# Emits one line of the form "<nodeID>, NA, NA; weight = <sum of weights>";
# the two NA fields stand in for the child IDs a leaf does not have.
#
# x:   a node list with elements `nodeID` and `weights`.
# ...: accepted for compatibility with the print() generic; unused.
#
# Returns `x` invisibly, as is conventional for print methods.
print.TerminalNode <- function(x, ...) {
    # Spell out `weights` in full rather than relying on `$` partial matching
    # of `weight`, which returns NULL when the prefix is ambiguous.
    cat(sprintf("%d, NA, NA; weight = %d\n", x$nodeID, sum(x$weights)))
    invisible(x)
}

# Print the tree starting from its root node. Each output line has the form
# "nodeID, left child ID, right child ID; weight = sum of weights", with NA
# in place of the child IDs for terminal nodes.
print(x@tree)
# 1, 2, 3; weight = 150
# 2, NA, NA; weight = 50
# 3, 4, 7; weight = 100
# 4, 5, 6; weight = 54
# 5, NA, NA; weight = 46
# 6, NA, NA; weight = 8
# 7, NA, NA; weight = 46

更新:这是一个根据您的要求格式化树的递归函数:

# Recursively print one line per terminal node of a tree.
#
# Each line lists the split conditions on the path from the root to the
# leaf, then the criterion and statistic of the split directly above the
# leaf, then the sum of the leaf's weights.
#
# x:   a node list. Inner nodes provide `terminal`, `psplit` (with
#      `variableName` and `splitpoint`), `criterion` (with `maxcriterion`
#      and `statistic`), `left` and `right`; leaves provide `terminal` and
#      `weights`.
# res: character vector of the path pieces collected so far (NULL at the
#      root).
#
# Called for its printed output; always returns NULL invisibly.
format_tree <- function(x, res=NULL) {
  # Leaf: emit the accumulated path and stop.
  if (x$terminal) {
    cat(paste(res, collapse=", "), ", weights=", sum(x$weights),
        "\n", sep="")
    return(invisible(NULL))
  }

  # Extend the path with this node's split condition (rendered with `fmt`)
  # and recurse into `child`. The criterion/statistic text is built only
  # when `child` is a leaf, so it describes the last split on the path.
  descend <- function(child, fmt) {
    path <- c(res,
              sprintf(fmt, x$psplit$variableName, x$psplit$splitpoint))
    if (child$terminal) {
      path <- c(path,
                sprintf("criterion=%.3f, statistic=%.3f",
                        x$criterion$maxcriterion,
                        max(x$criterion$statistic)))
    }
    format_tree(child, path)
  }

  descend(x$left, "%s<=%.3f")
  descend(x$right, "%s>%.3f")
  invisible(NULL)
}
# Walk the tree from its root node; one line is printed per terminal node.
format_tree(x@tree)

# Petal.Length<=1.900, criterion=1.000, statistic=140.264, weights=50
# Petal.Length>1.900, Petal.Width<=1.700, Petal.Length<=4.800, criterion=0.999, statistic=13.865, weights=46
# Petal.Length>1.900, Petal.Width<=1.700, Petal.Length>4.800, criterion=0.999, statistic=13.865, weights=8
# Petal.Length>1.900, Petal.Width>1.700, criterion=1.000, statistic=67.894, weights=46
于 2013-11-25T14:00:46.373 回答