目前使用该tidymodels
框架并努力理解随机森林(使用ranger
)和增强回归树(使用xgboost
)的训练模型输出的一些差异。
我的问题是关于使用xgboost
- 特别是如何访问预测/拟合正在训练的基础模型的训练数据(不使用predict
)。
为了澄清我的意思,在拟合随机森林模型时,我可以通过rf_fit
两种方式探索拟合模型(在下面的 reprex 中)及其对训练数据的预测。
- 使用
predict()
-调用predict(rf_fit, cells, type = "prob"
。(方法一)。 rf_fit
直接从 ( ) 获取预测rf_fit$fit$predictions
(方法 2)。
由于此处已阐明的原因,这些导致不同的预测。
在这种情况下,我对提升回归树和我的对象的等价物rf_fit$fit$predictions
(即方法 2)特别感兴趣。xgb_fit
我的问题有两个:
xgb_fit
训练模型的预测在哪里?(即rf_fit$fit$predictions
我们在随机森林模型中得到的等价物在哪里)?或者,我需要添加什么才能输出这些预测?- 如果以上是可能的,我应该如何解释这些预测?它们与 call 不同
predict
吗?如果是这样,它们代表什么(我收集的袋外估计对于提升回归树来说并不重要)?
(基本上,我想要training_logloss
在 1000 次迭代时产生错误的模型的预测xgb_fit$fit$evaluation_log
)。
# Load required libraries
library(tidymodels); library(modeldata)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
# Set seed
set.seed(123)
# Load in data
data(cells, package = "modeldata")
# Define Random Forest Model
rf_mod <- rand_forest(trees = 1000) %>%
set_mode("classification") %>%
set_engine("ranger")
# Define BRT Model
xgb_mod <- boost_tree(trees = 1000) %>%
set_mode("classification") %>%
set_engine("xgboost",
objective = 'binary:logistic',
eval_metric = 'logloss')
# Fit the models to training data
rf_fit <- rf_mod %>%
fit(class ~ ., data = cells)
xgb_fit <- xgb_mod %>%
fit(class ~ ., data = cells)
xgb_fit$fit$evaluation_log
#> iter training_logloss
#> 1: 1 0.542353
#> 2: 2 0.443275
#> 3: 3 0.382232
#> 4: 4 0.333377
#> 5: 5 0.303415
#> ---
#> 996: 996 0.001918
#> 997: 997 0.001917
#> 998: 998 0.001917
#> 999: 999 0.001916
#> 1000: 1000 0.001915
# Examine output predictions on training data for RANDOM FOREST Model
rf_whole <- predict(rf_fit, cells, type = "prob") # predictions based on whole fitted model
rf_oob <- head(rf_fit$fit$predictions) # predictions based on out of bag samples
## these are different to each other as we would expect
rf_whole$.pred_PS[1]
#> [1] 0.9229111
rf_oob[1, "PS"]
#> PS
#> 0.8503902
# Examine output predictions on training data for BOOSTED REGRESSION TREE Model
xgb_whole <- predict(xgb_fit, cells, type = "prob")
reprex
#> Error in eval(expr, envir, enclos): object 'reprex' not found
由reprex 包于 2021-10-05 创建(v2.0.1)