3

我在 iris 数据集上用相同的设置把同一个模型各训练了两次,以检验每种训练接口的可重复性。用 all.equal() 比较时,通过 recipes 接口训练的两个模型之间存在差异,而通过公式接口或 x/y 接口训练的两个模型则完全一致(计时信息除外)。这个问题似乎是 gbm 特有的(同样的代码在 Model = "rf" 或 "lm" 时可以得到一致的结果)。

recipes 接口中是否存在某些针对 gbm 的特殊处理?还是我的电脑设置造成的?想知道其他人能否重现这个问题。

library(plyr)
library(tidyverse)
library(gbm)
library(caret)
library(recipes)

# Shared setup for every train() call below ----

# Recipe supplied to the recipe interface: Sepal.Length is the outcome and
# every other iris column is a predictor. No preprocessing steps are added.
Recipe.Obj <- recipe(Sepal.Length ~ ., data = iris)

# Resampling control: cross-validation, keeping all hold-out predictions and
# all per-resample metrics so that two fits can be compared component by
# component with all.equal().
TC.Obj <- trainControl(
  method = "cv",
  savePredictions = "all",
  summaryFunction = defaultSummary,
  returnResamp = "all"
)

# Aliases referenced by the train() calls. Change Model here to rerun the
# whole comparison with a different caret method.
Model <- "gbm"
Recipe <- Recipe.Obj
TC <- TC.Obj
Training.Data.Set <- iris
Metric <- "RMSE"

# Using a recipe object
# Fit the same model twice through the recipe interface, resetting the RNG
# seed to 0 immediately before each train() call, then compare the two fits.

# First fit: the recipe is the first argument and the raw data is passed
# separately via `data`.
set.seed(0)
Model.Obj.1 <- train(Recipe,
                     method = Model,
                     data = Training.Data.Set,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

# Second fit: identical arguments and identical seed.
set.seed(0)
Model.Obj.2 <- train(Recipe,
                     method = Model,
                     data = Training.Data.Set,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

# does not return equal objects: all.equal() reports differences in the
# resampling results and predictions, not only in the timings (see the
# output printed below)
all.equal(Model.Obj.1, Model.Obj.2)

[1] "Component “results”: Component “RMSE”: Mean relative difference: 0.0006642504"     
 [2] "Component “results”: Component “Rsquared”: Mean relative difference: 0.0007520043" 
 [3] "Component “results”: Component “MAE”: Mean relative difference: 0.001153074"       
 [4] "Component “results”: Component “RMSESD”: Mean relative difference: 0.001743611"    
 [5] "Component “results”: Component “RsquaredSD”: Mean relative difference: 0.006758813"
 [6] "Component “results”: Component “MAESD”: Mean relative difference: 0.006780553"     
 [7] "Component “pred”: Component “pred”: Mean relative difference: 0.00312338"          
 [8] "Component “resample”: Component “RMSE”: Mean relative difference: 0.003475617"     
 [9] "Component “resample”: Component “Rsquared”: Mean relative difference: 0.002615116" 
[10] "Component “resample”: Component “MAE”: Mean relative difference: 0.004711215"      
[11] "Component “times”: Component “everything”: Mean relative difference: 0.148289"     
[12] "Component “times”: Component “final”: Mean relative difference: 0.5"

# Using formula
# Same experiment through the formula interface: two fits with the same
# arguments, each preceded by set.seed(0), then compared with all.equal().

# First fit.
set.seed(0)
Model.Obj.3 <- train(Sepal.Length ~ .,
                     method = Model,
                     data = Training.Data.Set,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

# Second fit: identical arguments and identical seed.
set.seed(0)
Model.Obj.4 <- train(Sepal.Length ~ .,
                     method = Model,
                     data = Training.Data.Set,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

# returns equal objects except for times
all.equal(Model.Obj.3, Model.Obj.4)

# Using x/y
# Same experiment through the x/y interface: the predictors are every column
# except the first, and the outcome is the first column (Sepal.Length in
# iris). Two fits, each preceded by set.seed(0), are compared with all.equal().

# First fit.
set.seed(0)
Model.Obj.5 <- train(Training.Data.Set[,-1],Training.Data.Set[,1],
                     method = Model,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)

# Second fit: identical arguments and identical seed.
set.seed(0)
Model.Obj.6 <- train(Training.Data.Set[,-1], Training.Data.Set[,1],
                     method = Model,
                     trControl = TC,
                     metric = Metric,
                     verbose = FALSE,
                     tuneLength = 3
)
# returns equal objects except for times
all.equal(Model.Obj.5, Model.Obj.6)

会话信息:

sessionInfo()

R version 3.5.2 (2018-12-20)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 18362)

Matrix products: default

locale:
[1] LC_COLLATE=English_Australia.1252  LC_CTYPE=English_Australia.1252    LC_MONETARY=English_Australia.1252 LC_NUMERIC=C                      
[5] LC_TIME=English_Australia.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] recipes_0.1.11  caret_6.0-86    lattice_0.20-38 gbm_2.1.5       forcats_0.4.0   stringr_1.4.0   dplyr_0.8.4     purrr_0.3.3     readr_1.3.1    
[10] tidyr_1.0.2     tibble_2.1.3    ggplot2_3.2.1   tidyverse_1.3.0 plyr_1.8.4     

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.1           lubridate_1.7.4      prettyunits_1.0.2    ps_1.3.0             class_7.3-14         assertthat_0.2.1    
 [7] packrat_0.5.0        ipred_0.9-8          foreach_1.4.4        R6_2.4.0             cellranger_1.1.0     backports_1.1.4     
[13] stats4_3.5.2         reprex_0.3.0         httr_1.4.1           pillar_1.4.3         rlang_0.4.5          lazyeval_0.2.1      
[19] readxl_1.3.1         data.table_1.11.8    rstudioapi_0.11      callr_3.4.3          rpart_4.1-13         Matrix_1.2-15       
[25] splines_3.5.2        gower_0.2.0          munsell_0.5.0        broom_0.5.4          compiler_3.5.2       modelr_0.1.6        
[31] pkgconfig_2.0.2      pkgbuild_1.0.6.9000  nnet_7.3-12          tidyselect_0.2.5     prodlim_2018.04.18   gridExtra_2.3       
[37] codetools_0.2-15     fansi_0.4.0          crayon_1.3.4         dbplyr_1.4.2         withr_2.1.2          ModelMetrics_1.2.2.2
[43] MASS_7.3-51.5        grid_3.5.2           nlme_3.1-137         jsonlite_1.6.1       gtable_0.2.0         lifecycle_0.1.0     
[49] DBI_1.0.0            magrittr_1.5         pROC_1.13.0          scales_1.0.0         cli_2.0.2            stringi_1.3.1       
[55] reshape2_1.4.3       fs_1.3.1             timeDate_3043.102    xml2_1.2.2           generics_0.0.2       vctrs_0.2.3         
[61] lava_1.6.5           iterators_1.0.10     tools_3.5.2          glue_1.4.0           hms_0.5.3            processx_3.4.1      
[67] survival_2.43-3      colorspace_1.4-0     rvest_0.3.5          haven_2.2.0         
4

0 回答 0