If you used one of the tune_* functions to find the best parameters for your model, and then finalized your workflow with those parameters, the next step is to train or fit that workflow one more time on the whole training set. Let's walk through an example.
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom 0.7.0 ✓ recipes 0.1.13
#> ✓ dials 0.0.8 ✓ rsample 0.0.7
#> ✓ dplyr 1.0.0 ✓ tibble 3.0.3
#> ✓ ggplot2 3.3.2 ✓ tidyr 1.1.0
#> ✓ infer 0.5.3 ✓ tune 0.1.1
#> ✓ modeldata 0.0.2 ✓ workflows 0.1.2
#> ✓ parsnip 0.1.2 ✓ yardstick 0.0.7
#> ✓ purrr 0.3.4
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
## pretend this is your training data
data("hpc_data")
xgb_spec <- boost_tree(
trees = 1000,
tree_depth = tune(),
min_n = tune()
) %>%
set_engine("xgboost") %>%
set_mode("classification")
hpc_folds <- vfold_cv(hpc_data, strata = class)
xgb_grid <- grid_latin_hypercube(
tree_depth(),
min_n(),
size = 5
)
xgb_wf <- workflow() %>%
add_formula(class ~ .) %>%
add_model(xgb_spec)
xgb_wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: boost_tree()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> class ~ .
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (classification)
#>
#> Main Arguments:
#> trees = 1000
#> min_n = tune()
#> tree_depth = tune()
#>
#> Computational engine: xgboost
doParallel::registerDoParallel()
set.seed(123)
xgb_res <- tune_grid(
xgb_wf,
resamples = hpc_folds,
grid = xgb_grid
)
xgb_res
#> # Tuning results
#> # 10-fold cross-validation using stratification
#> # A tibble: 10 x 4
#> splits id .metrics .notes
#> <list> <chr> <list> <list>
#> 1 <split [3.9K/434]> Fold01 <tibble [10 × 6]> <tibble [0 × 1]>
#> 2 <split [3.9K/434]> Fold02 <tibble [10 × 6]> <tibble [0 × 1]>
#> 3 <split [3.9K/434]> Fold03 <tibble [10 × 6]> <tibble [0 × 1]>
#> 4 <split [3.9K/434]> Fold04 <tibble [10 × 6]> <tibble [0 × 1]>
#> 5 <split [3.9K/434]> Fold05 <tibble [10 × 6]> <tibble [0 × 1]>
#> 6 <split [3.9K/434]> Fold06 <tibble [10 × 6]> <tibble [0 × 1]>
#> 7 <split [3.9K/433]> Fold07 <tibble [10 × 6]> <tibble [0 × 1]>
#> 8 <split [3.9K/432]> Fold08 <tibble [10 × 6]> <tibble [0 × 1]>
#> 9 <split [3.9K/431]> Fold09 <tibble [10 × 6]> <tibble [0 × 1]>
#> 10 <split [3.9K/431]> Fold10 <tibble [10 × 6]> <tibble [0 × 1]>
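Before finalizing, you may want to look at how the candidate parameter sets performed. The tune package provides helpers for this; a quick sketch:

```r
# Performance of every candidate, averaged across the 10 resamples
collect_metrics(xgb_res)

# Rank the candidates for a single metric
show_best(xgb_res, "roc_auc")

# The single best candidate (this is what we pass to finalize_workflow() below)
select_best(xgb_res, "roc_auc")
```

`select_best()` returns a one-row tibble of the winning `tree_depth` and `min_n` values, which is exactly the shape `finalize_workflow()` expects.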
Next, let's finalize this workflow and then fit() it to the training data. (The tuning process used the training data, but that was to find the best model parameters, not to train the model itself.)
trained_wf <- xgb_wf %>%
finalize_workflow(
select_best(xgb_res, "roc_auc")
) %>%
fit(hpc_data)
trained_wf
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: boost_tree()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> class ~ .
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> ##### xgb.Booster
#> raw: 2 Mb
#> call:
#> xgboost::xgb.train(params = list(eta = 0.3, max_depth = 3L, gamma = 0,
#> colsample_bytree = 1, min_child_weight = 9L, subsample = 1),
#> data = x, nrounds = 1000, watchlist = wlist, verbose = 0,
#> objective = "multi:softprob", num_class = 4L, nthread = 1)
#> params (as set within xgb.train):
#> eta = "0.3", max_depth = "3", gamma = "0", colsample_bytree = "1", min_child_weight = "9", subsample = "1", objective = "multi:softprob", num_class = "4", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#> niter
#> callbacks:
#> cb.evaluation.log()
#> # of features: 26
#> niter: 1000
#> nfeatures : 26
#> evaluation_log:
#> iter training_merror
#> 1 0.320942
#> 2 0.301778
#> ---
#> 999 0.010390
#> 1000 0.010390
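If you need the underlying xgboost object itself (say, for variable importance), you can pull the parsnip fit out of the trained workflow. A sketch, using the workflows API from the package versions loaded above:

```r
# Extract the parsnip model fit from the trained workflow
xgb_fit <- pull_workflow_fit(trained_wf)

# The raw xgb.Booster lives in the $fit slot and can be passed
# straight to xgboost's own functions
xgboost::xgb.importance(model = xgb_fit$fit)
```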
Now let's say we have some brand new data. You can predict() on the trained workflow with that new data.
brand_new_data <- hpc_data[5, -8]
brand_new_data
#> # A tibble: 1 x 7
#> protocol compounds input_fields iterations num_pending hour day
#> <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 E 100 82 20 0 10.4 Fri
predict(trained_wf, new_data = brand_new_data)
#> # A tibble: 1 x 1
#> .pred_class
#> <fct>
#> 1 VF
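The default prediction type for a classification model is the hard class label shown above. If you want class probabilities instead, you can ask predict() for them:

```r
# One .pred_* column per class (VF, F, M, L in hpc_data)
predict(trained_wf, new_data = brand_new_data, type = "prob")
```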
Created on 2020-07-17 by the reprex package (v0.3.0)