1

我正在学习新的 tidymodels 框架的原理,所以我可能会误解一些基本的东西。

我提供了一个独立的示例,其中包含一个真实的(从我的工作中获取的)数据集。请把它当作一个给定的,我需要使用除了最近的观察之外的所有观察作为训练集,并且只有最近的观察作为测试集(所以在这种情况下,测试集只是一个观察)。

但是,我收到一个我无法破译的错误。任何建议表示赞赏。

谢谢!

library(tidyverse) 

library(tidymodels)


df_ini <- structure(list(year = c(1998, 2002, 2004, 2005, 2006, 2007, 2008, 
2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018), 
    capital_n1132g_lag_1 = c(3446.5, 4091.1, 3655.1, 3633.3, 
    3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 
    3718.6, 3467.9, 4214.2, 4237.4, 4450.2), capital_n117g_lag_1 = c(4920.9, 
    7810.6, 8560.3, 8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 
    11554.3, 11849.9, 13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 
    17647.1, 18273.8), capital_n11mg_lag_1 = c(16846, 19605, 
    19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6, 
    20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2, 
    29790.1), employment_be_lag_1 = c(2834.42, 2839.72, 2765.53, 
    2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 
    2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9
    ), employment_c_lag_1 = c(2612.76, 2623.69, 2552.89, 2518.57, 
    2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 
    2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48), employment_j_lag_1 = c(292.93, 
    389.2, 389.45, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 
    392.18, 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4
    ), employment_k_lag_1 = c(505.33, 507.12, 510.25, 504.63, 
    515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 
    524.13, 518.89, 511.57, 505.32, 496.41), employment_mn_lag_1 = c(945.59, 
    1217.96, 1289.55, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 
    1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 
    2021.51, 2109.71), employment_oq_lag_1 = c(3065.87, 3191.75, 
    3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 
    3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 
    4171.72), employment_total_lag_1 = c(14509.58, 15127.99, 
    15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 
    16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 
    17142.13, 17365.32, 17650.21), gdp_b1gq_lag_1 = c(187849.7, 
    220525, 231862.5, 242348.3, 254075, 267824.4, 283978, 293761.9, 
    288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 
    344269.3, 357608, 369341.3), gdp_p3_lag_1 = c(139695.2, 161175.8, 
    169405.6, 176316.4, 185871.1, 194102, 200944.4, 208857.1, 
    213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6, 
    249404.3, 257166.5, 265900.2), gdp_p61_lag_1 = c(50117.6, 
    71948.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1, 
    91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3, 
    129183.6, 131524, 140057.8), gdp_p62_lag_1 = c(19441, 26444.4, 
    28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9, 
    39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5, 
    59584.7), price_index_lag_1 = c(1.2, 2.3, 1.3, 2, 2.1, 1.7, 
    2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 1, 2.2), value_be_lag_1 = c(40533.1, 
    48207.1, 48673.2, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 
    56837.8, 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 
    71152.6, 72698.8), value_c_lag_1 = c(33441.8, 40446.6, 40467.4, 
    42014.6, 44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 
    51467.7, 53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196
    ), value_j_lag_1 = c(5483.7, 7326.1, 7934.1, 7756.1, 8134.2, 
    8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 
    10361.4, 10695.4, 11455.3, 11720.6), value_k_lag_1 = c(9210.6, 
    9977.3, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 
    12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 
    13236.4, 13744.1), value_mn_lag_1 = c(10444, 14061.4, 15706.6, 
    16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 
    24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6
    ), value_oq_lag_1 = c(29902.7, 34179.2, 36126.8, 37329.6, 
    38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 
    49381.5, 50261.7, 51624.3, 53715, 55926.4, 57637.1), value_total_lag_1 = c(167323.4, 
    197076.7, 207247.6, 216098.3, 225888.1, 239076, 253604.6, 
    262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 
    307037.7, 318952.7, 329396.1), capital_n1132g_lag_2 = c(3599.2, 
    3996.9, 3638.4, 3655.1, 3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 
    3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 3467.9, 4214.2, 4237.4
    ), capital_n117g_lag_2 = c(4636.2, 7008.5, 8369.6, 8560.3, 
    8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 
    13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1), capital_n11mg_lag_2 = c(17181.5, 
    19677.8, 18749.6, 19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 
    23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 
    25019.2, 27608.2), employment_be_lag_2 = c(2870.33, 2840.19, 
    2775.22, 2765.53, 2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 
    2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 
    2750.52), employment_c_lag_2 = c(2626.2, 2621.08, 2562.53, 
    2552.89, 2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 
    2447.65, 2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97
    ), employment_j_lag_2 = c(275.08, 374.56, 400.75, 389.45, 
    387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 
    419.75, 427.59, 438.96, 440.33, 460.84), employment_k_lag_2 = c(500.9, 
    505.13, 502.42, 510.25, 504.63, 515.39, 523.45, 536.6, 550.14, 
    546.68, 539.96, 536.58, 534.98, 524.13, 518.89, 511.57, 505.32
    ), employment_mn_lag_2 = c(904.38, 1143.78, 1248.01, 1289.55, 
    1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 1762.55, 
    1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51), employment_oq_lag_2 = c(3028.85, 
    3162.77, 3241.36, 3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 
    3577.75, 3683.85, 3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 
    4002.74, 4095.59), employment_total_lag_2 = c(14404.29, 15019.87, 
    15113.52, 15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 
    16356.53, 16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 
    17039.6, 17142.13, 17365.32), gdp_b1gq_lag_2 = c(186928.7, 
    213606.4, 226735.3, 231862.5, 242348.3, 254075, 267824.4, 
    283978, 293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 
    323910.2, 333146.1, 344269.3, 357608), gdp_p3_lag_2 = c(140335.8, 
    156117.3, 164107.8, 169405.6, 176316.4, 185871.1, 194102, 
    200944.4, 208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 
    238329.3, 243860.6, 249404.3, 257166.5), gdp_p61_lag_2 = c(44541.4, 
    67701.6, 74691.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 
    113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 
    126109.3, 129183.6, 131524), gdp_p62_lag_2 = c(19504.2, 24888.9, 
    28063.4, 28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 
    38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 
    55885.5), value_be_lag_2 = c(40076.7, 46109.4, 47967.1, 48673.2, 
    50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 
    61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6), value_c_lag_2 = c(32955.4, 
    38908.4, 40192.9, 40467.4, 42014.6, 44229, 47735.5, 51552.4, 
    51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169, 
    57458.7, 60962.8), value_j_lag_2 = c(5576.8, 6313.9, 7737.1, 
    7934.1, 7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 
    9217.1, 9405.1, 9802.1, 10361.4, 10695.4, 11455.3), value_k_lag_2 = c(9191, 
    10458, 10225.2, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 
    13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 
    13482.9, 13236.4), value_mn_lag_2 = c(10092, 12942.5, 15074, 
    15706.6, 16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 
    23255.2, 24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7
    ), value_oq_lag_2 = c(30224.3, 33251.5, 35065.6, 36126.8, 
    37329.6, 38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 
    47980.9, 49381.5, 50261.7, 51624.3, 53715, 55926.4), value_total_lag_2 = c(167141.8, 
    190624.9, 202353.5, 207247.6, 216098.3, 225888.1, 239076, 
    253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 
    297230.1, 307037.7, 318952.7), berd = c(2146.085, 3130.884, 
    3556.479, 4207.669, 4448.676, 4845.861, 5232.63, 5092.902, 
    5520.422, 5692.841, 6540.457, 6778.42, 7324.679, 7498.488, 
    7824.51, 7888.444, 8461.72)), row.names = c(NA, -17L), class = c("tbl_df", 
"tbl", "data.frame"))





set.seed(1234)  ## to make the results reproducible






## I need a particular custom split of my dataset: the test set consists of only the most recent observation, whereas all the rest is the training set

## see https://github.com/tidymodels/rsample/issues/158


indices <-
  list(analysis   = seq(nrow(df_ini)-1), 
       assessment = nrow(df_ini)
       )

df_split <- make_splits(indices, df_ini)


## df_split <- initial_split(df_ini) ## with the default splitting,
## ## the code works

df_train <- training(df_split)
df_test <- testing(df_split)

folded_data <- vfold_cv(df_train,3)



glmnet_recipe <- 
    recipe(formula = berd ~ ., data = df_train) %>%
    update_role(year, new_role = "ID") %>%
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors(), -all_nominal()) 

glmnet_spec <- 
  linear_reg(penalty = tune(), mixture = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("glmnet") 

glmnet_workflow <- 
  workflow() %>% 
  add_recipe(glmnet_recipe) %>% 
  add_model(glmnet_spec) 




glmnet_grid <- tidyr::crossing(penalty = 10^seq(-6, -1, length.out = 20), mixture = c(0.05, 
    0.2, 0.4, 0.6, 0.8, 1)) 

glmnet_tune <- 
  tune_grid(glmnet_workflow, resamples = folded_data, grid = glmnet_grid,control = control_grid(save_pred = TRUE) ) 

print(collect_metrics(glmnet_tune))
#> # A tibble: 240 x 8
#>       penalty mixture .metric .estimator    mean     n std_err .config 
#>         <dbl>   <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>   
#>  1 0.000001      0.05 rmse    standard   375.        3 48.9    Model001
#>  2 0.000001      0.05 rsq     standard     0.929     3  0.0420 Model001
#>  3 0.00000183    0.05 rmse    standard   375.        3 48.9    Model002
#>  4 0.00000183    0.05 rsq     standard     0.929     3  0.0420 Model002
#>  5 0.00000336    0.05 rmse    standard   375.        3 48.9    Model003
#>  6 0.00000336    0.05 rsq     standard     0.929     3  0.0420 Model003
#>  7 0.00000616    0.05 rmse    standard   375.        3 48.9    Model004
#>  8 0.00000616    0.05 rsq     standard     0.929     3  0.0420 Model004
#>  9 0.0000113     0.05 rmse    standard   375.        3 48.9    Model005
#> 10 0.0000113     0.05 rsq     standard     0.929     3  0.0420 Model005
#> # … with 230 more rows

print(show_best(glmnet_tune, "rmse"))
#> # A tibble: 5 x 8
#>      penalty mixture .metric .estimator  mean     n std_err .config 
#>        <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>   
#> 1 0.000001      0.05 rmse    standard    375.     3    48.9 Model001
#> 2 0.00000183    0.05 rmse    standard    375.     3    48.9 Model002
#> 3 0.00000336    0.05 rmse    standard    375.     3    48.9 Model003
#> 4 0.00000616    0.05 rmse    standard    375.     3    48.9 Model004
#> 5 0.0000113     0.05 rmse    standard    375.     3    48.9 Model005

best_net <- select_best(glmnet_tune, "rmse")


final_net <- finalize_workflow(
  glmnet_workflow,
  best_net
)


final_res_net <- last_fit(final_net, df_split)
#> x : internal: Error in data.frame(..., check.names = FALSE): arguments imply...
#> Warning: All models failed in [fit_resamples()]. See the `.notes` column.


print(final_res_net)
#> Warning: This tuning result has notes. Example notes on model fitting include:
#> internal: Error in data.frame(..., check.names = FALSE): arguments imply differing number of rows: 2, 0
#> # Resampling results
#> # Monte Carlo cross-validation (0.94/0.059) with 1 resamples  
#> # A tibble: 1 x 5
#>   splits         id               .metrics .notes           .predictions
#>   <list>         <chr>            <list>   <list>           <list>      
#> 1 <split [16/1]> train/test split <NULL>   <tibble [1 × 1]> <NULL>

final_fit <- final_res_net %>%
    collect_predictions()

reprex 包(v0.3.0.9001)于 2020 年 10 月 15 日创建

4

0 回答 0