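When creating a filter in mlr3, how can I compute the filter scores from the training data only?
Once the filter has been created, how do I apply it in the modeling process so that the training data is subset to only the features whose filter score is above a given threshold?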
library(mlr3)
library(mlr3filters)
library(mlr3learners)
library(tidyverse)
data(iris)
iris <- iris %>%
  select(-Species)
tsk <- mlr3::TaskRegr$new("iris",
                          backend = iris,
                          target = "Sepal.Length")
# split row IDs into train (80%) and test (20%) sets
trn_ids <- sample(tsk$row_ids, floor(0.8 * length(tsk$row_ids)), replace = FALSE)
tst_ids <- setdiff(tsk$row_ids, trn_ids)
# create a filter
filter <- flt("correlation", method = "spearman")
# Question 1: how to calculate the filter only for the train IDs?
filter$calculate(tsk)
print(filter)
# Question 2: how to use only variables with correlation X or greater in training?
learner <- mlr_learners$get("regr.glmnet")
learner$train(tsk, row_ids = trn_ids)
prediction <- learner$predict(tsk, row_ids = tst_ids)
prediction$response
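
For reference, here is a minimal sketch of one way both steps might be done; it is not necessarily the recommended approach. It assumes the mlr3pipelines package is installed (the script above does not load it) and wraps the filter in a pipeline, so that the scores are computed from whatever rows are passed to `$train()`. The seed and the cutoff value of 0.5 are hypothetical choices made only for illustration.

```r
library(mlr3)
library(mlr3filters)
library(mlr3learners)
library(mlr3pipelines)  # assumption: not loaded in the original script

set.seed(1)  # assumption: fixed seed so the split is reproducible

# Same regression task as above: Sepal.Length from the other numeric columns
dat <- iris[, setdiff(names(iris), "Species")]
tsk <- TaskRegr$new("iris", backend = dat, target = "Sepal.Length")

trn_ids <- sample(tsk$row_ids, floor(0.8 * tsk$nrow))
tst_ids <- setdiff(tsk$row_ids, trn_ids)

# Question 1: score the features on the training rows only.
# clone() leaves the original task untouched; filter() keeps only the given rows.
tsk_trn <- tsk$clone()$filter(trn_ids)
filter <- flt("correlation", method = "spearman")
filter$calculate(tsk_trn)
scores <- as.data.table(filter)  # columns: feature, score (absolute correlation)
print(scores)

# Question 2: put the filter in front of the learner in a pipeline.
# filter.cutoff drops features whose score does not clear the threshold;
# the scores are recomputed on the rows passed to $train().
cutoff <- 0.5  # hypothetical threshold
graph <- po("filter",
            filter = flt("correlation", method = "spearman"),
            filter.cutoff = cutoff) %>>%
  lrn("regr.glmnet")
glrn <- as_learner(graph)

glrn$train(tsk, row_ids = trn_ids)
prediction <- glrn$predict(tsk, row_ids = tst_ids)
head(prediction$response)
```

Because the filter is part of the graph learner, the same object can also be passed to `resample()` and the filter would then be refit on each training fold. Note that glmnet needs at least two features, so a cutoff that leaves fewer than two would make training fail.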