我正在学习新的 tidymodels 框架的原理,所以我可能会误解一些基本的东西。
下面是一个可独立运行的示例,其中使用了一个真实的数据集(取自我的工作)。请把以下要求视为既定前提:我需要把除最近一次观测之外的所有观测作为训练集,仅把最近一次观测作为测试集(因此在本例中,测试集只包含一条观测)。
但是,我遇到了一个自己无法理解的错误。非常感谢任何建议。
谢谢!
library(tidyverse)
library(tidymodels)
# Example data set, stored inline as a tibble (class "tbl_df") with 17 rows.
# Columns: `year` (1998, 2002, then yearly from 2004 to 2018), a set of
# `*_lag_1` and `*_lag_2` numeric columns (capital_*, employment_*, gdp_*,
# price_index_lag_1, value_*), and `berd`, which the recipe further down in
# this script uses as the outcome. Rows are ordered by `year`, so the last
# row is the most recent observation.
df_ini <- structure(list(year = c(1998, 2002, 2004, 2005, 2006, 2007, 2008,
2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018),
capital_n1132g_lag_1 = c(3446.5, 4091.1, 3655.1, 3633.3,
3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6,
3718.6, 3467.9, 4214.2, 4237.4, 4450.2), capital_n117g_lag_1 = c(4920.9,
7810.6, 8560.3, 8679.9, 8938.9, 9823.8, 10467.1, 11047.1,
11554.3, 11849.9, 13465.4, 13927.5, 15510.2, 15754.4, 16584.7,
17647.1, 18273.8), capital_n11mg_lag_1 = c(16846, 19605,
19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6,
20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2,
29790.1), employment_be_lag_1 = c(2834.42, 2839.72, 2765.53,
2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1,
2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9
), employment_c_lag_1 = c(2612.76, 2623.69, 2552.89, 2518.57,
2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1,
2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48), employment_j_lag_1 = c(292.93,
389.2, 389.45, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91,
392.18, 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4
), employment_k_lag_1 = c(505.33, 507.12, 510.25, 504.63,
515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98,
524.13, 518.89, 511.57, 505.32, 496.41), employment_mn_lag_1 = c(945.59,
1217.96, 1289.55, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76,
1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83,
2021.51, 2109.71), employment_oq_lag_1 = c(3065.87, 3191.75,
3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85,
3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59,
4171.72), employment_total_lag_1 = c(14509.58, 15127.99,
15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 16356.53,
16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 17039.6,
17142.13, 17365.32, 17650.21), gdp_b1gq_lag_1 = c(187849.7,
220525, 231862.5, 242348.3, 254075, 267824.4, 283978, 293761.9,
288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 333146.1,
344269.3, 357608, 369341.3), gdp_p3_lag_1 = c(139695.2, 161175.8,
169405.6, 176316.4, 185871.1, 194102, 200944.4, 208857.1,
213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6,
249404.3, 257166.5, 265900.2), gdp_p61_lag_1 = c(50117.6,
71948.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1,
91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3,
129183.6, 131524, 140057.8), gdp_p62_lag_1 = c(19441, 26444.4,
28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9,
39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5,
59584.7), price_index_lag_1 = c(1.2, 2.3, 1.3, 2, 2.1, 1.7,
2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 1, 2.2), value_be_lag_1 = c(40533.1,
48207.1, 48673.2, 50737.6, 52955.2, 56872.4, 60864.9, 61029,
56837.8, 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4,
71152.6, 72698.8), value_c_lag_1 = c(33441.8, 40446.6, 40467.4,
42014.6, 44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3,
51467.7, 53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196
), value_j_lag_1 = c(5483.7, 7326.1, 7934.1, 7756.1, 8134.2,
8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1,
10361.4, 10695.4, 11455.3, 11720.6), value_k_lag_1 = c(9210.6,
9977.3, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2,
12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9,
13236.4, 13744.1), value_mn_lag_1 = c(10444, 14061.4, 15706.6,
16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2,
24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6
), value_oq_lag_1 = c(29902.7, 34179.2, 36126.8, 37329.6,
38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9,
49381.5, 50261.7, 51624.3, 53715, 55926.4, 57637.1), value_total_lag_1 = c(167323.4,
197076.7, 207247.6, 216098.3, 225888.1, 239076, 253604.6,
262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1,
307037.7, 318952.7, 329396.1), capital_n1132g_lag_2 = c(3599.2,
3996.9, 3638.4, 3655.1, 3633.3, 3616.2, 3450.7, 3596.8, 3867.2,
3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 3467.9, 4214.2, 4237.4
), capital_n117g_lag_2 = c(4636.2, 7008.5, 8369.6, 8560.3,
8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9,
13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1), capital_n11mg_lag_2 = c(17181.5,
19677.8, 18749.6, 19381.2, 19433.5, 20051.6, 20569.8, 22646.1,
23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 24057.7, 23832.8,
25019.2, 27608.2), employment_be_lag_2 = c(2870.33, 2840.19,
2775.22, 2765.53, 2731.08, 2709.59, 2708.39, 2774.06, 2795.6,
2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 2725.66, 2735.69,
2750.52), employment_c_lag_2 = c(2626.2, 2621.08, 2562.53,
2552.89, 2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97,
2447.65, 2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97
), employment_j_lag_2 = c(275.08, 374.56, 400.75, 389.45,
387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 410.85,
419.75, 427.59, 438.96, 440.33, 460.84), employment_k_lag_2 = c(500.9,
505.13, 502.42, 510.25, 504.63, 515.39, 523.45, 536.6, 550.14,
546.68, 539.96, 536.58, 534.98, 524.13, 518.89, 511.57, 505.32
), employment_mn_lag_2 = c(904.38, 1143.78, 1248.01, 1289.55,
1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 1762.55,
1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51), employment_oq_lag_2 = c(3028.85,
3162.77, 3241.36, 3280.36, 3317.09, 3401.65, 3476.63, 3508.01,
3577.75, 3683.85, 3759.23, 3798.35, 3850.17, 3877.24, 3924.06,
4002.74, 4095.59), employment_total_lag_2 = c(14404.29, 15019.87,
15113.52, 15212.11, 15307.28, 15491.61, 15762.92, 16050.92,
16356.53, 16269.97, 16392.87, 16647.79, 16820.66, 16879.06,
17039.6, 17142.13, 17365.32), gdp_b1gq_lag_2 = c(186928.7,
213606.4, 226735.3, 231862.5, 242348.3, 254075, 267824.4,
283978, 293761.9, 288044.1, 295896.6, 310128.6, 318653.1,
323910.2, 333146.1, 344269.3, 357608), gdp_p3_lag_2 = c(140335.8,
156117.3, 164107.8, 169405.6, 176316.4, 185871.1, 194102,
200944.4, 208857.1, 213630.1, 218947.2, 227250.8, 233638.1,
238329.3, 243860.6, 249404.3, 257166.5), gdp_p61_lag_2 = c(44541.4,
67701.6, 74691.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2,
113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1,
126109.3, 129183.6, 131524), gdp_p62_lag_2 = c(19504.2, 24888.9,
28063.4, 28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8,
38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8,
55885.5), value_be_lag_2 = c(40076.7, 46109.4, 47967.1, 48673.2,
50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6,
61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6), value_c_lag_2 = c(32955.4,
38908.4, 40192.9, 40467.4, 42014.6, 44229, 47735.5, 51552.4,
51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169,
57458.7, 60962.8), value_j_lag_2 = c(5576.8, 6313.9, 7737.1,
7934.1, 7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9,
9217.1, 9405.1, 9802.1, 10361.4, 10695.4, 11455.3), value_k_lag_2 = c(9191,
10458, 10225.2, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7,
13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4,
13482.9, 13236.4), value_mn_lag_2 = c(10092, 12942.5, 15074,
15706.6, 16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490,
23255.2, 24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7
), value_oq_lag_2 = c(30224.3, 33251.5, 35065.6, 36126.8,
37329.6, 38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6,
47980.9, 49381.5, 50261.7, 51624.3, 53715, 55926.4), value_total_lag_2 = c(167141.8,
190624.9, 202353.5, 207247.6, 216098.3, 225888.1, 239076,
253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3,
297230.1, 307037.7, 318952.7), berd = c(2146.085, 3130.884,
3556.479, 4207.669, 4448.676, 4845.861, 5232.63, 5092.902,
5520.422, 5692.841, 6540.457, 6778.42, 7324.679, 7498.488,
7824.51, 7888.444, 8461.72)), row.names = c(NA, -17L), class = c("tbl_df",
"tbl", "data.frame"))
# Fix the RNG seed so the resampling below is reproducible.
set.seed(1234)

# Custom train/test split: every row except the last one goes to the training
# (analysis) set, and the last row -- the most recent observation -- alone
# forms the test (assessment) set.
# See https://github.com/tidymodels/rsample/issues/158
indices <- list(
  # seq_len() instead of seq(): stays empty (rather than counting down) if the
  # data ever has fewer than two rows.
  analysis = seq_len(nrow(df_ini) - 1L),
  assessment = nrow(df_ini)
)
df_split <- make_splits(indices, df_ini)
# df_split <- initial_split(df_ini)  # with the default splitting, the code works

df_train <- training(df_split)
df_test <- testing(df_split)

# 3-fold cross-validation on the training rows; used for hyperparameter tuning.
folded_data <- vfold_cv(df_train, v = 3)
# Preprocessing: model `berd` on all other columns, keep `year` only as an
# identifier (role "ID", so it is not a predictor), drop zero-variance
# predictors, then normalize the predictors that are not nominal.
glmnet_recipe <- recipe(formula = berd ~ ., data = df_train) %>%
  update_role(year, new_role = "ID") %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors(), -all_nominal())

# glmnet linear regression; `penalty` and `mixture` are left as tune()
# placeholders to be chosen by the grid search.
glmnet_spec <- linear_reg(
  mode = "regression",
  penalty = tune(),
  mixture = tune()
) %>%
  set_engine("glmnet")

# Bundle the model specification and the recipe into a single workflow.
glmnet_workflow <- workflow() %>%
  add_model(glmnet_spec) %>%
  add_recipe(glmnet_recipe)
# Tuning grid: 20 log-spaced penalty values (1e-6 to 1e-1) crossed with six
# mixture values, i.e. 120 candidate combinations.
glmnet_grid <- tidyr::crossing(
  penalty = 10^seq(-6, -1, length.out = 20),
  mixture = c(0.05, 0.2, 0.4, 0.6, 0.8, 1)
)

# Evaluate every grid point on the cross-validation folds, keeping the
# out-of-sample predictions.
glmnet_tune <- glmnet_workflow %>%
  tune_grid(
    resamples = folded_data,
    grid = glmnet_grid,
    control = control_grid(save_pred = TRUE)
  )

# Show the resampled performance metrics for each candidate.
glmnet_tune %>%
  collect_metrics() %>%
  print()
#> # A tibble: 240 x 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Model001
#> 2 0.000001 0.05 rsq standard 0.929 3 0.0420 Model001
#> 3 0.00000183 0.05 rmse standard 375. 3 48.9 Model002
#> 4 0.00000183 0.05 rsq standard 0.929 3 0.0420 Model002
#> 5 0.00000336 0.05 rmse standard 375. 3 48.9 Model003
#> 6 0.00000336 0.05 rsq standard 0.929 3 0.0420 Model003
#> 7 0.00000616 0.05 rmse standard 375. 3 48.9 Model004
#> 8 0.00000616 0.05 rsq standard 0.929 3 0.0420 Model004
#> 9 0.0000113 0.05 rmse standard 375. 3 48.9 Model005
#> 10 0.0000113 0.05 rsq standard 0.929 3 0.0420 Model005
#> # … with 230 more rows
# Print the top candidates ranked by RMSE. The metric is passed by name so the
# call does not depend on the position of `metric` in show_best()'s signature.
print(show_best(glmnet_tune, metric = "rmse"))
#> # A tibble: 5 x 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Model001
#> 2 0.00000183 0.05 rmse standard 375. 3 48.9 Model002
#> 3 0.00000336 0.05 rmse standard 375. 3 48.9 Model003
#> 4 0.00000616 0.05 rmse standard 375. 3 48.9 Model004
#> 5 0.0000113 0.05 rmse standard 375. 3 48.9 Model005
# Select the hyperparameter combination with the best RMSE. The metric is
# passed by name so the call does not depend on argument position.
best_net <- select_best(glmnet_tune, metric = "rmse")

# Replace the tune() placeholders in the workflow with the selected values.
final_net <- finalize_workflow(glmnet_workflow, best_net)

# Fit the finalized workflow on the training rows of the custom split and
# evaluate it on the split's single test row.
final_res_net <- last_fit(final_net, split = df_split)
#> x : internal: Error in data.frame(..., check.names = FALSE): arguments imply...
#> Warning: All models failed in [fit_resamples()]. See the `.notes` column.
# Display the last_fit() result, including any notes recorded during fitting.
final_res_net %>%
  print()
#> Warning: This tuning result has notes. Example notes on model fitting include:
#> internal: Error in data.frame(..., check.names = FALSE): arguments imply differing number of rows: 2, 0
#> # Resampling results
#> # Monte Carlo cross-validation (0.94/0.059) with 1 resamples
#> # A tibble: 1 x 5
#> splits id .metrics .notes .predictions
#> <list> <chr> <list> <list> <list>
#> 1 <split [16/1]> train/test split <NULL> <tibble [1 × 1]> <NULL>
# Pull the test-set predictions out of the last_fit() result.
# NOTE(review): in the output recorded in this file the `.predictions` column
# is NULL, so this presumably cannot return predictions until last_fit()
# succeeds -- confirm.
final_fit <- collect_predictions(final_res_net)
由 reprex 包(v0.3.0.9001)于 2020 年 10 月 15 日创建