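I trained the same model on the iris dataset several times to study the reproducibility of each interface. When I use all.equal() on models trained through the recipes interface, the models appear to differ, but this does not happen with the formula or x/y interfaces. The problem seems to be specific to gbm (the same code gives identical fits with Model = "rf" or "lm").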
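Is there something gbm-specific about the recipe interface, or is it something in my computer's setup? I'm curious whether others can reproduce the error.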
library(plyr)
library(tidyverse)
library(gbm)
library(caret)
library(recipes)
# recipe to be supplied
Recipe.Obj <- recipe(Sepal.Length ~ ., data = iris)
# train control object
TC.Obj <- trainControl("cv", savePredictions = "all", summaryFunction = defaultSummary, returnResamp = "all")
Model = "gbm"
Recipe = Recipe.Obj
TC = TC.Obj
Training.Data.Set = iris
Metric = "RMSE"
# Using a recipe object
set.seed(0)
Model.Obj.1 <- train(Recipe,
method = Model,
data = Training.Data.Set,
trControl = TC,
metric = Metric,
verbose = FALSE,
tuneLength = 3
)
set.seed(0)
Model.Obj.2 <- train(Recipe,
method = Model,
data = Training.Data.Set,
trControl = TC,
metric = Metric,
verbose = FALSE,
tuneLength = 3
)
# does not return equal objects
all.equal(Model.Obj.1, Model.Obj.2)
[1] "Component “results”: Component “RMSE”: Mean relative difference: 0.0006642504"
[2] "Component “results”: Component “Rsquared”: Mean relative difference: 0.0007520043"
[3] "Component “results”: Component “MAE”: Mean relative difference: 0.001153074"
[4] "Component “results”: Component “RMSESD”: Mean relative difference: 0.001743611"
[5] "Component “results”: Component “RsquaredSD”: Mean relative difference: 0.006758813"
[6] "Component “results”: Component “MAESD”: Mean relative difference: 0.006780553"
[7] "Component “pred”: Component “pred”: Mean relative difference: 0.00312338"
[8] "Component “resample”: Component “RMSE”: Mean relative difference: 0.003475617"
[9] "Component “resample”: Component “Rsquared”: Mean relative difference: 0.002615116"
[10] "Component “resample”: Component “MAE”: Mean relative difference: 0.004711215"
[11] "Component “times”: Component “everything”: Mean relative difference: 0.148289"
[12] "Component “times”: Component “final”: Mean relative difference: 0.5"
# Using formula
set.seed(0)
Model.Obj.3 <- train(Sepal.Length ~ .,
method = Model,
data = Training.Data.Set,
trControl = TC,
metric = Metric,
verbose = FALSE,
tuneLength = 3
)
set.seed(0)
Model.Obj.4 <- train(Sepal.Length ~ .,
method = Model,
data = Training.Data.Set,
trControl = TC,
metric = Metric,
verbose = FALSE,
tuneLength = 3
)
# returns equal objects except for times
all.equal(Model.Obj.3, Model.Obj.4)
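Because the `times` component records wall-clock timings, it will differ between any two runs, which makes a plain all.equal() noisy. A minimal sketch of a stricter check, assuming a `train` object is an ordinary list underneath its class (the helper name `all_equal_ignoring_times` is hypothetical, not part of caret):

```r
# Hypothetical helper: compare two caret train objects while ignoring the
# `times` component, which records elapsed time and always varies.
all_equal_ignoring_times <- function(fit_a, fit_b, ...) {
  drop_times <- function(fit) {
    x <- unclass(fit)  # treat the train object as a plain list
    x$times <- NULL    # remove the timing component before comparing
    x
  }
  all.equal(drop_times(fit_a), drop_times(fit_b), ...)
}
```

Given the observations in this question, `all_equal_ignoring_times(Model.Obj.3, Model.Obj.4)` would be expected to return TRUE, while `all_equal_ignoring_times(Model.Obj.1, Model.Obj.2)` should still report the `results`, `pred` and `resample` differences.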
# Using x/y
set.seed(0)
Model.Obj.5 <- train(Training.Data.Set[,-1],Training.Data.Set[,1],
method = Model,
trControl = TC,
metric = Metric,
verbose = FALSE,
tuneLength = 3
)
set.seed(0)
Model.Obj.6 <- train(Training.Data.Set[,-1], Training.Data.Set[,1],
method = Model,
trControl = TC,
metric = Metric,
verbose = FALSE,
tuneLength = 3
)
# returns equal objects except for times
all.equal(Model.Obj.5, Model.Obj.6)
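set.seed(0) before train() fixes the fold assignment, but caret's trainControl() also accepts a `seeds` argument that sets the seed inside each resampling iteration: a list of B + 1 elements, where each of the first B is an integer vector with one seed per model evaluated and the last is a single integer for the final fit. A minimal sketch for building that list, assuming 10-fold CV (the default for "cv") and assuming 9 seeds per resample is at least the number of models caret evaluates for gbm with tuneLength = 3 (3 interaction depths by 3 tree counts). The helper name `make_train_seeds` is hypothetical, and whether this removes the recipe-specific difference is untested:

```r
# Hypothetical helper: build a reproducible `seeds` list for trainControl().
# n_resamples: number of resampling iterations (10 for the default "cv").
# n_models: seeds per resample; assumed >= number of tuning models evaluated.
make_train_seeds <- function(n_resamples = 10, n_models = 9, seed = 0) {
  set.seed(seed)
  seeds <- vector("list", n_resamples + 1)
  for (i in seq_len(n_resamples)) {
    seeds[[i]] <- sample.int(1e6, n_models)
  }
  # The last element is a single integer used for the final model fit.
  seeds[[n_resamples + 1]] <- sample.int(1e6, 1)
  seeds
}
```

It would then be passed as `trainControl("cv", seeds = make_train_seeds(), savePredictions = "all", summaryFunction = defaultSummary, returnResamp = "all")`.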
Session info:
sessionInfo()
R version 3.5.2 (2018-12-20)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 18362)
Matrix products: default
locale:
[1] LC_COLLATE=English_Australia.1252 LC_CTYPE=English_Australia.1252 LC_MONETARY=English_Australia.1252 LC_NUMERIC=C
[5] LC_TIME=English_Australia.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] recipes_0.1.11 caret_6.0-86 lattice_0.20-38 gbm_2.1.5 forcats_0.4.0 stringr_1.4.0 dplyr_0.8.4 purrr_0.3.3 readr_1.3.1
[10] tidyr_1.0.2 tibble_2.1.3 ggplot2_3.2.1 tidyverse_1.3.0 plyr_1.8.4
loaded via a namespace (and not attached):
[1] Rcpp_1.0.1 lubridate_1.7.4 prettyunits_1.0.2 ps_1.3.0 class_7.3-14 assertthat_0.2.1
[7] packrat_0.5.0 ipred_0.9-8 foreach_1.4.4 R6_2.4.0 cellranger_1.1.0 backports_1.1.4
[13] stats4_3.5.2 reprex_0.3.0 httr_1.4.1 pillar_1.4.3 rlang_0.4.5 lazyeval_0.2.1
[19] readxl_1.3.1 data.table_1.11.8 rstudioapi_0.11 callr_3.4.3 rpart_4.1-13 Matrix_1.2-15
[25] splines_3.5.2 gower_0.2.0 munsell_0.5.0 broom_0.5.4 compiler_3.5.2 modelr_0.1.6
[31] pkgconfig_2.0.2 pkgbuild_1.0.6.9000 nnet_7.3-12 tidyselect_0.2.5 prodlim_2018.04.18 gridExtra_2.3
[37] codetools_0.2-15 fansi_0.4.0 crayon_1.3.4 dbplyr_1.4.2 withr_2.1.2 ModelMetrics_1.2.2.2
[43] MASS_7.3-51.5 grid_3.5.2 nlme_3.1-137 jsonlite_1.6.1 gtable_0.2.0 lifecycle_0.1.0
[49] DBI_1.0.0 magrittr_1.5 pROC_1.13.0 scales_1.0.0 cli_2.0.2 stringi_1.3.1
[55] reshape2_1.4.3 fs_1.3.1 timeDate_3043.102 xml2_1.2.2 generics_0.0.2 vctrs_0.2.3
[61] lava_1.6.5 iterators_1.0.10 tools_3.5.2 glue_1.4.0 hms_0.5.3 processx_3.4.1
[67] survival_2.43-3 colorspace_1.4-0 rvest_0.3.5 haven_2.2.0