2

你好,我正在尝试用 tidymodels 做一个线性回归模型的示例。我已经用该框架正确拟合了模型,并在工作流中用 collect_metrics() 和 collect_predictions() 对其进行了测试。但是,当我尝试用该模型对新数据进行预测时,却无法成功。我参考并改写的是下面这个例子:


# Train the finalized workflow on the training data.
rf_wflow_final_fit <- fit(
  rf_wflow_final,
  data = dia_train
)

# Pull the prepped recipe and the model fit out of the fitted workflow.
dia_rec3 <- pull_workflow_prepped_recipe(rf_wflow_final_fit)
rf_final_fit <- pull_workflow_fit(rf_wflow_final_fit)

# The test set is baked with the extracted recipe before it is handed to the
# extracted model fit; the predictions are stored next to the observed data.
dia_test[[".pred"]] <- predict(
  rf_final_fit,
  new_data = bake(dia_rec3, dia_test)
)[[".pred"]]

# Log of the observed price, used as the truth column for the metrics below.
dia_test[["logprice"]] <- log(dia_test[["price"]])

# Compare the predictions against the logged price.
metrics(
  dia_test,
  truth = logprice,
  estimate = .pred
)
#> # A tibble: 3 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard      0.113 
#> 2 rsq     standard      0.988 
#> 3 mae     standard      0.0846

这就是我正在做的事情:

# Load the data and create a seeded train/test split (prop = 4/5 for training).
data("diamonds")
set.seed(234589)
diamonds_split <- initial_split(diamonds, prop = 4/5)
diamonds_train <- training(diamonds_split)
diamonds_test <- testing(diamonds_split)

# Preprocessing recipe, in order: log the outcome, normalize the non-nominal
# predictors, dummy-encode the nominal columns, add a degree-2 polynomial
# expansion of carat.
diamonds_recipe <- recipe(price ~ ., data = diamonds_train) %>%
  step_log(all_outcomes()) %>%
  step_normalize(all_predictors(), -all_nominal()) %>%
  step_dummy(all_nominal()) %>%
  step_poly(carat, degree = 2)

# Stand-alone prepped copy of the recipe (used later with bake()).
preprocesados <- prep(diamonds_recipe)

# Linear regression specification using the "lm" engine.
lr_model <- set_mode(
  set_engine(linear_reg(), "lm"),
  "regression"
)

# Bundle the (unprepped) recipe and the model specification into a workflow.
lr_workflow <- add_model(
  add_recipe(workflow(), diamonds_recipe),
  lr_model
)

# Run last_fit() on the split, then collect the metrics and predictions.
lr_fitted_workflow <- last_fit(lr_workflow, diamonds_split)
performance <- collect_metrics(lr_fitted_workflow)
test_predictions <- collect_predictions(lr_fitted_workflow)

# Fit the workflow on the complete diamonds data set.
final_model <- fit(lr_workflow, diamonds)

到这里为止一切看起来都正常,但当我尝试使用 predict() 函数时出现了错误。

我试过这个:

# NOTE: this call reproduces the error shown below. `final_model` is a fitted
# workflow that contains the recipe, yet it is given data that was already
# baked with `preprocesados`; the error reports the raw predictor columns
# ('carat', 'cut', 'color', 'clarity') as missing from that baked data.
predict(final_model, new_data = bake(preprocesados, diamonds_test))

Error: The following required columns are missing: 'carat', 'cut', 'color', 'clarity'.
Traceback:

1. predict(final_model, new_data = bake(preprocesados, diamonds_test))
2. predict.workflow(final_model, new_data = bake(preprocesados, 
 .     diamonds_test))
3. hardhat::forge(new_data, blueprint)
4. forge.data.frame(new_data, blueprint)
5. blueprint$forge$clean(blueprint = blueprint, new_data = new_data, 
 .     outcomes = outcomes)
6. shrink(new_data, blueprint$ptypes$predictors)
7. validate_column_names(data, cols)
8. glubort("The following required columns are missing: {missing_names}.")
9. abort(glue(..., .sep = .sep, .envir = .envir))
10. signal_abort(cnd)

和这个:

# A single hand-written observation: predictor columns only (no `price`), with
# the categorical values given as character strings.
new_diamond <- tribble(~carat, ~cut, ~color, ~clarity, ~depth, ~table, ~x, ~y, ~z,
                        0.23,   "Ideal",    "E",    "SI2",  61.5,   55, 3.95, 3.98, 2.43)

# NOTE: this call reproduces the warning and error shown below. The traceback
# ends in bake.step_log(): the recipe's step_log(all_outcomes()) runs while
# baking, but `new_diamond` has no `price` column to log (0 rows assigned to a
# 1-row tibble). The warning is about 'cut', 'color' and 'clarity', which were
# factors when the recipe was prepped.
predict(final_model, new_data = bake(preprocesados, new_diamond))

Warning message:
“ There were 3 columns that were factors when the recipe was prepped:
 'cut', 'color', 'clarity'.
 This may cause errors when processing new data.”

Error: Assigned data `log(new_data[[col_names[i]]] + object$offset, base = object$base)` must be compatible with existing data.
✖ Existing data has 1 row.
✖ Assigned data has 0 rows.
ℹ Row updates require a list value. Do you need `list()` or `as.list()`?
Traceback:

1. predict(final_model, new_data = bake(preprocesados, new_diamond))
2. predict.workflow(final_model, new_data = bake(preprocesados, 
 .     new_diamond))
3. hardhat::forge(new_data, blueprint)
4. bake(preprocesados, new_diamond)
5. bake.recipe(preprocesados, new_diamond)
6. bake(object$steps[[i]], new_data = new_data)
7. bake.step_log(object$steps[[i]], new_data = new_data)
8. `[<-`(`*tmp*`, , col_names[i], value = numeric(0))
9. `[<-.tbl_df`(`*tmp*`, , col_names[i], value = numeric(0))
10. tbl_subassign(x, i, j, value, i_arg, j_arg, substitute(value))
...

任何帮助都会非常感激

4

1 回答 1

4

尽量不要把 bake() 和工作流(workflow)混在一起使用:工作流在预测时会自动应用其中的配方(recipe),因此应直接传入未经预处理的原始数据。另外请记住,对 all_outcomes() 使用的步骤通常需要设置 skip = TRUE,因为待预测的新数据中通常不包含结果变量。

library(tidymodels)
#> -- Attaching packages --------------------------------------------------------------------------------------------- tidymodels 0.1.1 --
#> v broom     0.7.0      v recipes   0.1.13
#> v dials     0.0.8      v rsample   0.0.7 
#> v dplyr     1.0.0      v tibble    3.0.3 
#> v ggplot2   3.3.2      v tidyr     1.1.0 
#> v infer     0.5.3      v tune      0.1.1 
#> v modeldata 0.0.2      v workflows 0.1.2 
#> v parsnip   0.1.2      v yardstick 0.0.7 
#> v purrr     0.3.4
#> -- Conflicts ------------------------------------------------------------------------------------------------ tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
# Load the data and create a seeded train/test split (prop = 4/5 for training).
data("diamonds")
set.seed(234589)
diamonds_split <- initial_split(diamonds, prop = 4/5)

diamonds_train <- training(diamonds_split)
diamonds_test <- testing(diamonds_split)

# Preprocessing recipe. The outcome transformation is marked `skip = TRUE` so
# that it is applied when the recipe is prepped for fitting but skipped when
# new data are baked at prediction time, where `price` may be absent.
# `TRUE` is spelled out because `T` is an ordinary variable that can be
# reassigned, whereas `TRUE` is a reserved constant.
diamonds_recipe <-
  recipe(price ~ ., data = diamonds_train) %>%
  step_log(all_outcomes(), skip = TRUE) %>%
  step_normalize(all_predictors(), -all_nominal()) %>%
  step_dummy(all_nominal()) %>%
  step_poly(carat, degree = 2)

# Stand-alone prepped recipe kept from the question; nothing below uses it,
# because the workflow is given the unprepped `diamonds_recipe`.
preprocesados <- prep(diamonds_recipe)

# Linear regression specification using the "lm" engine.
lr_model <-
  linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")

# Bundle the unprepped recipe and the model specification.
lr_workflow <- workflow() %>%
  add_recipe(diamonds_recipe) %>%
  add_model(lr_model)

# Fit the whole workflow (recipe + model) on the complete diamonds data set.
final_model <- fit(lr_workflow, diamonds)

# Pass the raw, un-baked data: the fitted workflow applies the recipe itself.
# Because the model was trained on log(price), `.pred` is on the log scale;
# use exp(.pred) to get back to the original price units.
predict(final_model, new_data = diamonds_test)
#> # A tibble: 10,787 x 1
#>    .pred
#>    <dbl>
#>  1  5.94
#>  2  5.91
#>  3  5.87
#>  4  6.23
#>  5  6.22
#>  6  6.29
#>  7  6.05
#>  8  6.08
#>  9  6.35
#> 10  6.04
#> # ... with 10,777 more rows

reprex 包(v0.3.0)于 2020 年 8 月 4 日创建

于 2020-08-04T03:46:33.867 回答