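I have a problem with the trafo function for the `K` parameter of SMOTE {smotefamily}. Specifically, when the number of nearest neighbours `K` is greater than or equal to the sample size, an error is returned (`warning("k should be less than sample size!")`) and the tuning process terminates.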
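During the internal resampling, the user has no way of making sure that `K` stays below the sample size. This would have to be controlled internally. For example, if `trafo_K = 2 ^ K >= sample_size` for some value of `K`, then set `trafo_K = sample_size - 1`.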
Is there a solution for this, or is one already on the way?
library("mlr3") # mlr3 base package
library("mlr3misc") # contains some helper functions
library("mlr3pipelines") # create ML pipelines
library("mlr3tuning") # tuning ML algorithms
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction
library("magrittr") # provides the %>% pipe used below
# get list of curated binary classification data sets (see https://arxiv.org/abs/1708.03731v2)
ds = listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)
# select imbalanced data sets (without categorical features as SMOTE cannot handle them)
ds = subset(ds, minority.class.size / number.of.instances < 0.2 &
  number.of.symbolic.features == 1)
ds
d = getOMLDataSet(980)
d
# make sure target is a factor and create mlr3 tasks
data = as.data.frame(d)
data[[d$target.features]] = as.factor(data[[d$target.features]])
task = TaskClassif$new(
  id = d$desc$name, backend = data,
  target = d$target.features)
task
# Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/
class_counts <- table(task$truth())
majority_to_minority_ratio <- class_counts[class_counts == max(class_counts)] /
  class_counts[class_counts == min(class_counts)]
# Pipe operator for SMOTE
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))
# Random Forest learner
rf <- lrn("classif.ranger", predict_type = "prob")
# Pipeline of Random Forest learner with SMOTE
graph <- po_smote %>>%
  po('learner', rf, id = 'rf')
graph$plot()
# Graph learner
rf_smote <- GraphLearner$new(graph, predict_type = 'prob')
rf_smote$predict_type <- 'prob'
# Parameter set in data table format
ps_table <- as.data.table(rf_smote$param_set)
View(ps_table[, 1:4])
# Define parameter search space for the SMOTE parameters
param_set <- ps_table$id %>%
  lapply(
    function(x) {
      if (grepl('smote.', x)) {
        if (grepl('.dup_size', x)) {
          ParamInt$new(x, lower = 1, upper = round(majority_to_minority_ratio))
        } else if (grepl('.K', x)) {
          ParamInt$new(x, lower = 1, upper = round(majority_to_minority_ratio))
        }
      }
    }
  )
param_set <- Filter(Negate(is.null), param_set)
param_set <- ParamSet$new(param_set)
# Apply transformation function on SMOTE's K (= The number of nearest neighbors used for sampling new values. See SMOTE().)
param_set$trafo <- function(x, param_set) {
  index <- which(grepl('.K', names(x)))
  if (length(index) != 0) {
    x[[index]] <- round(3 ^ x[[index]]) # Intentionally define a trafo that won't work
  }
  x
}
# Define and instantiate resampling strategy to be applied within pipeline
cv <- rsmp("cv", folds = 2)
cv$instantiate(task)
# Set up tuning instance
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 3),
  store_models = TRUE)
tuner <- TunerRandomSearch$new()
# Tune pipe learner to find optimal SMOTE parameter values
tuner$optimize(instance)
This is what happens:
INFO [11:00:14.904] Benchmark with 2 resampling iterations
INFO [11:00:14.919] Applying learner 'smote.rf' on task 'optdigits' (iter 2/2)
Error in get.knnx(data, query, k, algorithm) : ANN: ERROR------->
In addition: Warning message:
In get.knnx(data, query, k, algorithm) : k should be less than sample size!
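Until this is handled internally, one possible workaround for a fixed, pre-instantiated resampling is to derive an upper bound for `K` from the training sets and cap the trafo at that bound. The sketch below rests on several assumptions. It assumes that smotefamily's `SMOTE()` searches neighbours among the minority-class rows only. It uses the `function(x, param_set)` trafo signature from the reprex (paradox 0.2.0). `min_class_count()`, `k_max` and `capped_trafo()` are hypothetical names, and the margin of two is an assumed safety buffer. The built-in `tsk("pima")` stands in for the OpenML task only so that the code runs on its own.

```r
library("mlr3")

# Hypothetical helper: smallest per-class count over all training sets of an
# instantiated resampling, i.e. the minority-class size in the worst fold
min_class_count <- function(task, resampling) {
  counts <- vapply(seq_len(resampling$iters), function(i) {
    min(table(task$truth(resampling$train_set(i))))
  }, numeric(1))
  min(counts)
}

# Built-in stand-in task (500 "neg" / 268 "pos"); with the reprex above,
# pass `task` and the instantiated `cv` instead
pima <- tsk("pima")
cv_pima <- rsmp("cv", folds = 2)
cv_pima$instantiate(pima)

n_min <- min_class_count(pima, cv_pima)
# Assumed safety margin: keep K at least two below n_min in case the
# neighbour search also counts the query point itself
k_max <- n_min - 2

# Same 3^K trafo as in the reprex, but capped at k_max
capped_trafo <- function(x, param_set) {
  if ("smote.K" %in% names(x)) {
    x[["smote.K"]] <- as.integer(min(round(3^x[["smote.K"]]), k_max))
  }
  x
}

capped_trafo(list(smote.dup_size = 3L, smote.K = 2L))$smote.K  # 3^2 = 9, kept if k_max > 9
capped_trafo(list(smote.dup_size = 3L, smote.K = 9L))$smote.K  # 3^9 = 19683, replaced by k_max
```

With the reprex, this would mean computing `k_max` from `task` and `cv` and then setting `param_set$trafo <- capped_trafo`. The cap only works because the resampling is instantiated before tuning. It does not help when the training sets are not known in advance, for example in nested resampling.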
Session info:
R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 16299)
Matrix products: default
locale:
[1] LC_COLLATE=English_United Kingdom.1252 LC_CTYPE=English_United Kingdom.1252
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C
[5] LC_TIME=English_United Kingdom.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] smotefamily_1.3.1 OpenML_1.10 mlr3viz_0.1.1.9002
[4] mlr3tuning_0.1.2-9000 mlr3pipelines_0.1.2.9000 mlr3misc_0.2.0
[7] mlr3learners_0.2.0 mlr3filters_0.2.0.9000 mlr3_0.2.0-9000
[10] paradox_0.2.0 yardstick_0.0.5 rsample_0.0.5
[13] recipes_0.1.9 parsnip_0.0.5 infer_0.5.1
[16] dials_0.0.4 scales_1.1.0 broom_0.5.4
[19] tidymodels_0.0.3 reshape2_1.4.3 janitor_1.2.1
[22] data.table_1.12.8 forcats_0.4.0 stringr_1.4.0
[25] dplyr_0.8.4 purrr_0.3.3 readr_1.3.1
[28] tidyr_1.0.2 tibble_3.0.1 ggplot2_3.3.0
[31] tidyverse_1.3.0
loaded via a namespace (and not attached):
[1] utf8_1.1.4 tidyselect_1.0.0 lme4_1.1-21
[4] htmlwidgets_1.5.1 grid_3.6.2 ranger_0.12.1
[7] pROC_1.16.1 munsell_0.5.0 codetools_0.2-16
[10] bbotk_0.1 DT_0.12 future_1.17.0
[13] miniUI_0.1.1.1 withr_2.2.0 colorspace_1.4-1
[16] knitr_1.28 uuid_0.1-4 rstudioapi_0.10
[19] stats4_3.6.2 bayesplot_1.7.1 listenv_0.8.0
[22] rstan_2.19.2 lgr_0.3.4 DiceDesign_1.8-1
[25] vctrs_0.2.4 generics_0.0.2 ipred_0.9-9
[28] xfun_0.12 R6_2.4.1 markdown_1.1
[31] mlr3measures_0.1.3-9000 rstanarm_2.19.2 lhs_1.0.1
[34] assertthat_0.2.1 promises_1.1.0 nnet_7.3-12
[37] gtable_0.3.0 globals_0.12.5 processx_3.4.1
[40] timeDate_3043.102 rlang_0.4.5 workflows_0.1.1
[43] BBmisc_1.11 splines_3.6.2 checkmate_2.0.0
[46] inline_0.3.15 yaml_2.2.1 modelr_0.1.5
[49] tidytext_0.2.2 threejs_0.3.3 crosstalk_1.0.0
[52] backports_1.1.6 httpuv_1.5.2 rsconnect_0.8.16
[55] tokenizers_0.2.1 tools_3.6.2 lava_1.6.6
[58] ellipsis_0.3.0 ggridges_0.5.2 Rcpp_1.0.4.6
[61] plyr_1.8.5 base64enc_0.1-3 visNetwork_2.0.9
[64] ps_1.3.0 prettyunits_1.1.1 rpart_4.1-15
[67] zoo_1.8-7 haven_2.2.0 fs_1.3.1
[70] furrr_0.1.0 magrittr_1.5 colourpicker_1.0
[73] reprex_0.3.0 GPfit_1.0-8 SnowballC_0.6.0
[76] packrat_0.5.0 matrixStats_0.55.0 tidyposterior_0.0.2
[79] hms_0.5.3 shinyjs_1.1 mime_0.8
[82] xtable_1.8-4 XML_3.99-0.3 tidypredict_0.4.3
[85] shinystan_2.5.0 readxl_1.3.1 gridExtra_2.3
[88] rstantools_2.0.0 compiler_3.6.2 crayon_1.3.4
[91] minqa_1.2.4 StanHeaders_2.21.0-1 htmltools_0.4.0
[94] later_1.0.0 lubridate_1.7.4 DBI_1.1.0
[97] dbplyr_1.4.2 MASS_7.3-51.4 boot_1.3-23
[100] Matrix_1.2-18 cli_2.0.1 parallel_3.6.2
[103] gower_0.2.1 igraph_1.2.4.2 pkgconfig_2.0.3
[106] xml2_1.2.2 foreach_1.4.7 dygraphs_1.1.1.6
[109] prodlim_2019.11.13 farff_1.1 rvest_0.3.5
[112] snakecase_0.11.0 janeaustenr_0.1.5 callr_3.4.1
[115] digest_0.6.25 cellranger_1.1.0 curl_4.3
[118] shiny_1.4.0 gtools_3.8.1 nloptr_1.2.1
[121] lifecycle_0.2.0 nlme_3.1-142 jsonlite_1.6.1
[124] fansi_0.4.1 pillar_1.4.3 lattice_0.20-38
[127] loo_2.2.0 fastmap_1.0.1 httr_1.4.1
[130] pkgbuild_1.0.6 survival_3.1-8 glue_1.4.0
[133] xts_0.12-0 FNN_1.1.3 shinythemes_1.1.2
[136] iterators_1.0.12 class_7.3-15 stringi_1.4.4
[139] memoise_1.1.0 future.apply_1.5.0
Many thanks.