1

我尝试使用updatein 中的函数覆盖默认调整值tidymodels,但值无法更新。

例如,在下面的代码中,我想将 的范围min_n从默认的 2 到 40 更改为 30 到 50。但是,值min_n保持在 2 和 40。

library(tidymodels)
#> -- Attaching packages --------------------------------------------------------------------------------------------------------------------------- tidymodels 0.1.0 --
#> v broom     0.7.0      v recipes   0.1.13
#> v dials     0.0.8      v rsample   0.0.7 
#> v dplyr     1.0.0      v tibble    3.0.3 
#> v ggplot2   3.3.2      v tune      0.1.1 
#> v infer     0.5.2      v workflows 0.1.2 
#> v parsnip   0.1.2      v yardstick 0.0.7 
#> v purrr     0.3.4
#> -- Conflicts ------------------------------------------------------------------------------------------------------------------------------ tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
rf <- decision_tree(cost_complexity = tune(), tree_depth = tune(), min_n = tune()) %>%
  set_mode("classification") %>%
  set_engine("rpart")

rf_wf <- workflow() %>% 
  add_model(rf) %>%
  add_formula(class ~ .)

param <- rf %>% parameters()
param %>% update(min_n = min_n(range = c(30L, 50L)))
#> Collection of 3 parameters for tuning
#> 
#>               id  parameter type object class
#>  cost_complexity cost_complexity    nparam[+]
#>       tree_depth      tree_depth    nparam[+]
#>            min_n           min_n    nparam[+]

rf_grid <- grid_regular(param, levels = 2)
rf_grid
#> # A tibble: 8 x 3
#>   cost_complexity tree_depth min_n
#>             <dbl>      <int> <int>
#> 1    0.0000000001          1     2
#> 2    0.1                   1     2
#> 3    0.0000000001         15     2
#> 4    0.1                  15     2
#> 5    0.0000000001          1    40
#> 6    0.1                   1    40
#> 7    0.0000000001         15    40
#> 8    0.1                  15    40

reprex 包(v0.3.0)于 2020-07-26 创建

4

1 回答 1

1

update方法返回一个新的参数对象——它不会更新您就地传递的值。你需要做

newparam <- param %>% update(min_n = min_n(range = c(30L, 50L)))
grid_regular(newparam, levels = 2)

#   cost_complexity tree_depth min_n
#             <dbl>      <int> <int>
# 1    0.0000000001          1    30
# 2    0.1                   1    30
# 3    0.0000000001         15    30
# 4    0.1                  15    30
# 5    0.0000000001          1    50
# 6    0.1                   1    50
# 7    0.0000000001         15    50
# 8    0.1                  15    50
于 2020-07-26T01:54:32.217 回答