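Hi, I'm trying to build an example linear regression model with tidymodels. I managed to fit the model correctly within the framework and to evaluate it in a workflow with collect_metrics() and collect_predictions(). However, when I try to use the model to make predictions on new data, I can't get it to work. I'm trying to adapt this example: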
rf_wflow_final_fit <- fit(rf_wflow_final, data = dia_train)
dia_rec3 <- pull_workflow_prepped_recipe(rf_wflow_final_fit)
rf_final_fit <- pull_workflow_fit(rf_wflow_final_fit)
dia_test$.pred <- predict(rf_final_fit,
                          new_data = bake(dia_rec3, dia_test))$.pred
dia_test$logprice <- log(dia_test$price)
metrics(dia_test, truth = logprice, estimate = .pred)
#> # A tibble: 3 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 rmse standard 0.113
#> 2 rsq standard 0.988
#> 3 mae standard 0.0846
This is what I'm doing:
library(tidymodels)
data("diamonds")
set.seed(234589)
diamonds_split <- initial_split(diamonds, prop = 4/5)
diamonds_train <- training(diamonds_split)
diamonds_test <- testing(diamonds_split)
diamonds_recipe <-
  recipe(price ~ ., data = diamonds_train) %>%
  step_log(all_outcomes()) %>%
  step_normalize(all_predictors(), -all_nominal()) %>%
  step_dummy(all_nominal()) %>%
  step_poly(carat, degree = 2)
preprocesados <- prep(diamonds_recipe)
lr_model <-
  linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")
lr_workflow <- workflow() %>%
  add_recipe(diamonds_recipe) %>%
  add_model(lr_model)
lr_fitted_workflow <- lr_workflow %>%
  last_fit(diamonds_split)
performance <- lr_fitted_workflow %>% collect_metrics()
test_predictions <- lr_fitted_workflow %>% collect_predictions()
final_model <- fit(lr_workflow, diamonds)
Up to this point everything seems to work, but I get an error when I try to use the predict function.
I tried this:
predict(final_model, new_data = bake(preprocesados, diamonds_test))
Error: The following required columns are missing: 'carat', 'cut', 'color', 'clarity'.
Traceback:
1. predict(final_model, new_data = bake(preprocesados, diamonds_test))
2. predict.workflow(final_model, new_data = bake(preprocesados,
. diamonds_test))
3. hardhat::forge(new_data, blueprint)
4. forge.data.frame(new_data, blueprint)
5. blueprint$forge$clean(blueprint = blueprint, new_data = new_data,
. outcomes = outcomes)
6. shrink(new_data, blueprint$ptypes$predictors)
7. validate_column_names(data, cols)
8. glubort("The following required columns are missing: {missing_names}.")
9. abort(glue(..., .sep = .sep, .envir = .envir))
10. signal_abort(cnd)
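`predict()` on a fitted workflow runs the workflow's own prepped recipe on `new_data`, so it first checks `new_data` for the original predictor columns. After `bake()`, `carat` has been replaced by `carat_poly_1`/`carat_poly_2` and `cut`, `color`, `clarity` by dummy columns, which matches the "required columns are missing" error. (`preprocesados` was also prepped on `diamonds_train` only, while `final_model` was fit on all of `diamonds`, so its statistics can differ from the recipe stored in the workflow.) A small self-contained sketch of that behaviour on a subset of `diamonds`, using a hypothetical two-predictor recipe; `dia_small`, `small_rec` and `small_fit` are illustrative names:

```r
library(tidymodels)  # attaches ggplot2 (diamonds), recipes, parsnip, workflows

data("diamonds")
dia_small <- diamonds %>% slice_head(n = 2000)

# Hypothetical two-predictor recipe; no step touches the outcome.
small_rec <- recipe(price ~ carat + cut, data = dia_small) %>%
  step_dummy(all_nominal())

small_fit <- workflow() %>%
  add_recipe(small_rec) %>%
  add_model(linear_reg() %>% set_engine("lm")) %>%
  fit(data = dia_small)

# Raw data: the workflow bakes it internally, so this works.
raw_pred <- predict(small_fit, new_data = dia_small)

# Already-baked data: `cut` has become dummy columns, so the workflow's
# column check fails before any prediction is made; keep the message.
baked_pred <- tryCatch(
  predict(small_fit, new_data = bake(prep(small_rec), new_data = dia_small)),
  error = function(e) conditionMessage(e)
)
```

For `final_model`, the analogous input is `diamonds_test` itself rather than its baked version, although the `step_log(all_outcomes())` step in that recipe may still need attention at prediction time.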
And this:
new_diamond <- tribble(~carat, ~cut, ~color, ~clarity, ~depth, ~table, ~x, ~y, ~z,
                       0.23, "Ideal", "E", "SI2", 61.5, 55, 3.95, 3.98, 2.43)
predict(final_model, new_data = bake(preprocesados, new_diamond))
Warning message:
“ There were 3 columns that were factors when the recipe was prepped:
'cut', 'color', 'clarity'.
This may cause errors when processing new data.”
Error: Assigned data `log(new_data[[col_names[i]]] + object$offset, base = object$base)` must be compatible with existing data.
✖ Existing data has 1 row.
✖ Assigned data has 0 rows.
ℹ Row updates require a list value. Do you need `list()` or `as.list()`?
Traceback:
1. predict(final_model, new_data = bake(preprocesados, new_diamond))
2. predict.workflow(final_model, new_data = bake(preprocesados,
. new_diamond))
3. hardhat::forge(new_data, blueprint)
4. bake(preprocesados, new_diamond)
5. bake.recipe(preprocesados, new_diamond)
6. bake(object$steps[[i]], new_data = new_data)
7. bake.step_log(object$steps[[i]], new_data = new_data)
8. `[<-`(`*tmp*`, , col_names[i], value = numeric(0))
9. `[<-.tbl_df`(`*tmp*`, , col_names[i], value = numeric(0))
10. tbl_subassign(x, i, j, value, i_arg, j_arg, substitute(value))
...
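In the second attempt, `bake(preprocesados, new_diamond)` runs every step, including `step_log(all_outcomes())`; `new_diamond` has no `price` column, so the log step yields zero values for a one-row tibble, hence "Assigned data has 0 rows". The warning refers to `cut`, `color` and `clarity` being character columns in the `tribble()` instead of the factors seen at prep time. Since `predict()` on a workflow does not pass an outcome column to the recipe either, an outcome step inside the recipe is likely to fail there as well. A minimal sketch, assuming the outcome is log-transformed before the recipe (a swap-in for `step_log(all_outcomes())`, not part of the original code); `diamonds_log`, `log_recipe` and `log_fit` are illustrative names:

```r
library(tidymodels)  # attaches ggplot2 (diamonds), recipes, parsnip, workflows

data("diamonds")

# Assumption: take log(price) before the recipe, so no step needs `price`
# when predicting on rows that do not have it.
diamonds_log <- diamonds %>% mutate(price = log(price))

log_recipe <-
  recipe(price ~ ., data = diamonds_log) %>%
  step_normalize(all_predictors(), -all_nominal()) %>%
  step_dummy(all_nominal()) %>%
  step_poly(carat, degree = 2)

log_fit <- workflow() %>%
  add_recipe(log_recipe) %>%
  add_model(linear_reg() %>% set_engine("lm")) %>%
  fit(data = diamonds_log)

# cut, color and clarity are ordered factors in diamonds; give the new row
# the same types and levels instead of plain character columns.
new_diamond <- tibble(
  carat   = 0.23,
  cut     = factor("Ideal", levels = levels(diamonds$cut), ordered = TRUE),
  color   = factor("E", levels = levels(diamonds$color), ordered = TRUE),
  clarity = factor("SI2", levels = levels(diamonds$clarity), ordered = TRUE),
  depth = 61.5, table = 55, x = 3.95, y = 3.98, z = 2.43
)

# No bake(): the fitted workflow applies its prepped recipe to new_diamond.
log_pred <- predict(log_fit, new_data = new_diamond)

# Predictions are on the log scale; exp() converts back to price in dollars.
exp(log_pred$.pred)
```

Fitting on the full (logged) data mirrors `final_model`; the same `log_fit` also accepts an unbaked test set for evaluation, with `.pred` and `price` both on the log scale.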
Any help would be greatly appreciated.