5

说我有

head(kyphosis)
inTrain <- sample(1:nrow(kyphosis), 45, replace = F)
TRAIN_KYPHOSIS <- kyphosis[inTrain,]
TEST_KYPHOSIS <- kyphosis[-inTrain,]

(kyph_tree <- rpart(Number ~ ., data = TRAIN_KYPHOSIS))

如何从拟合对象中获取每个观察的终端节点TEST_KYPHOSIS

如何从每个测试观察映射到的终端节点获取摘要,例如偏差和预测值?

4

2 回答 2

8

rpart实际上有这个功能但它没有暴露(奇怪的是,这是一个相当明显的要求)。

predict_nodes <-
    function (object, newdata, na.action = na.pass) {
        where <-
            if (missing(newdata)) 
                object$where
            else {
                if (is.null(attr(newdata, "terms"))) {
                    Terms <- delete.response(object$terms)
                    newdata <- model.frame(Terms, newdata, na.action = na.action, 
                                           xlev = attr(object, "xlevels"))
                    if (!is.null(cl <- attr(Terms, "dataClasses"))) 
                        .checkMFClasses(cl, newdata, TRUE)
                }
                rpart:::pred.rpart(object, rpart:::rpart.matrix(newdata))
            }
        as.integer(row.names(object$frame))[where]
    }

接着:

> predict_nodes(kyph_tree, TEST_KYPHOSIS)
 [1] 5 3 4 3 3 5 5 3 3 3 3 5 5 4 3 5 4 3 3 3 3 4 3 4 4 5 5 3 4 4 3 5 3 5 5 5
于 2015-05-06T21:51:32.763 回答
5

一种选择是将rpart对象转换partypartykit包中的类对象。这提供了一个用于处理递归聚会的通用工具包。转换很简单:

library("partykit")
(kyph_party <- as.party(kyph_tree))

Model formula:
Number ~ Kyphosis + Age + Start

Fitted party:
[1] root
|   [2] Start >= 15.5: 2.933 (n = 15, err = 10.9)
|   [3] Start < 15.5
|   |   [4] Age >= 112.5: 3.714 (n = 14, err = 18.9)
|   |   [5] Age < 112.5: 5.125 (n = 16, err = 29.8)

Number of inner nodes:    2
Number of terminal nodes: 3

(为了获得精确的可重复性,请set.seed(1)在运行我的代码之前运行您问题中的代码。)

对于此类的对象,对于plot()predict()fitted()等,有一些更灵活的方法。例如,plot(kyph_party)产生比默认值更多的信息显示plot(kyph_tree)。该fitted()方法提取了一个两列data.frame,其中包含拟合的节点数和在训练数据上观察到的响应。

kyph_fit <- fitted(kyph_party)
head(kyph_fit, 3)

  (fitted) (response)
1        5          6
2        2          2
3        4          3

有了这个,您可以轻松计算您感兴趣的任何数量,例如,每个节点内的均值、中位数或残差平方和。

tapply(kyph_fit[,2], kyph_fit[,1], mean)

       2        4        5 
2.933333 3.714286 5.125000 

tapply(kyph_fit[,2], kyph_fit[,1], median)

2 4 5 
3 4 5 

tapply(kyph_fit[,2], kyph_fit[,1], function(x) sum((x - mean(x))^2))

       2        4        5 
10.93333 18.85714 29.75000 

tapply()您可以使用您选择的任何其他函数来计算分组统计表,而不是简单的。

TEST_KYPHOSIS现在要了解从测试数据到树中哪个节点的哪个观察,您可以简单地使用以下predict(..., type = "node")方法:

kyph_pred <- predict(kyph_party, newdata = TEST_KYPHOSIS, type = "node")
head(kyph_pred)

 2  3  4  6  7 10 
 4  4  5  2  2  5 
于 2015-03-29T11:44:51.387 回答