1

SMOTE {smotefamily} 的 K 参数的 trafo 函数有问题。具体来说,当最近邻数量 K 大于或等于样本大小时,会返回错误(warning("k should be less than sample size!")),调参过程随即终止。

在内部重采样过程中,用户无法保证 K 小于样本大小。这必须在内部加以控制,例如:如果对于 K 的某个取值有 trafo_K = 2 ^ K >= sample_size,则令 trafo_K = sample_size - 1。

我想知道是否有解决方案,或者是否已经在路上?

library("mlr3") # mlr3 base package
library("mlr3misc") # contains some helper functions
library("mlr3pipelines") # create ML pipelines
library("mlr3tuning") # tuning ML algorithms
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction

# Query OpenML for curated binary classification data sets
# (see https://arxiv.org/abs/1708.03731v2), restricted by number of features
# and number of instances.
ds <- listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)

# Keep only imbalanced data sets (minority share below 20%); categorical
# features are excluded because SMOTE cannot handle them.
ds <- subset(
  ds,
  minority.class.size / number.of.instances < 0.2 &
    number.of.symbolic.features == 1
)
ds

# Download data set 980 and print it.
d <- getOMLDataSet(980)
d

# Convert to a data frame, make sure the target column is a factor, and wrap
# the data in an mlr3 classification task.
data <- as.data.frame(d)
data[[d$target.features]] <- as.factor(data[[d$target.features]])
task <- TaskClassif$new(
  id = d$desc$name,
  backend = data,
  target = d$target.features
)
task

# Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/

# Ratio of the largest to the smallest class size in the task.
# max()/min() always yield a single plain number; selecting with
# `class_counts == max(class_counts)` returned one element per tied class and
# carried the table attributes along.
class_counts <- table(task$truth())
majority_to_minority_ratio <- max(class_counts) / min(class_counts)

# SMOTE pipe operator; dup_size starts at the rounded class ratio.
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))

# Random forest (ranger) learner that predicts probabilities.
rf <- lrn("classif.ranger", predict_type = "prob")

# Chain SMOTE into the random forest and plot the resulting graph.
graph <- po_smote %>>% po("learner", rf, id = "rf")
graph$plot()

# Wrap the graph as a single learner that predicts probabilities.
rf_smote <- GraphLearner$new(graph, predict_type = "prob")
rf_smote$predict_type <- "prob"

# Hyperparameters of the graph learner as a data.table; show the first four
# columns in the data viewer.
ps_table <- as.data.table(rf_smote$param_set)
View(ps_table[, 1:4])

# Define parameter search space for the SMOTE parameters.
# Only the ids "smote.dup_size" and "smote.K" are tuned. The pattern is
# anchored and the dot is escaped: in the previous patterns ('smote.',
# '.dup_size', '.K') the dot was a regex wildcard and the match could occur
# anywhere in the id.
# NOTE(review): K shares dup_size's upper bound; K is transformed by the trafo
# afterwards, so this bounds the exponent -- confirm this is intended.
smote_ids <- grep("^smote\\.(dup_size|K)$", ps_table$id, value = TRUE)
param_set <- lapply(
  smote_ids,
  function(id) {
    ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
  }
)
param_set <- ParamSet$new(param_set)

# Apply transformation function on SMOTE's K (= The number of nearest
# neighbors used for sampling new values. See SMOTE().)
#
# x:         named list of sampled parameter values.
# param_set: the search space (not used by this trafo).
# Returns x with every entry whose name ends in ".K" replaced by round(3 ^ K).
#
# The transformation is intentionally left unbounded so that it reproduces
# the "k should be less than sample size!" failure.
param_set$trafo <- function(x, param_set) {
  # Anchored, escaped pattern: match only names ending in a literal ".K".
  # grep() returns integer(0) when nothing matches (or x has no names), so the
  # loop is simply skipped; several matches are each transformed.
  k_idx <- grep("\\.K$", names(x))
  for (i in k_idx) {
    x[[i]] <- round(3 ^ x[[i]])
  }
  x
}

# Define and instantiate resampling strategy to be applied within pipeline
# (cross-validation with 2 folds; instantiate() fixes the splits for `task`).
cv <- rsmp("cv", folds = 2)
cv$instantiate(task)

# Set up tuning instance
# `param_set` is the only unnamed argument, so R matches it positionally to
# the first formal that is not already matched by name.
# The terminator is configured with n_evals = 3.
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 3), 
  store_models = TRUE)
# Random search tuner
tuner <- TunerRandomSearch$new()

# Tune pipe learner to find optimal SMOTE parameter values
tuner$optimize(instance)

这就是发生的事情

INFO  [11:00:14.904] Benchmark with 2 resampling iterations 
INFO  [11:00:14.919] Applying learner 'smote.rf' on task 'optdigits' (iter 2/2) 
Error in get.knnx(data, query, k, algorithm) : ANN: ERROR------->
In addition: Warning message:
In get.knnx(data, query, k, algorithm) : k should be less than sample size!

会话信息

R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 16299)

Matrix products: default

locale:
[1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
[5] LC_TIME=English_United Kingdom.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] smotefamily_1.3.1        OpenML_1.10              mlr3viz_0.1.1.9002      
 [4] mlr3tuning_0.1.2-9000    mlr3pipelines_0.1.2.9000 mlr3misc_0.2.0          
 [7] mlr3learners_0.2.0       mlr3filters_0.2.0.9000   mlr3_0.2.0-9000         
[10] paradox_0.2.0            yardstick_0.0.5          rsample_0.0.5           
[13] recipes_0.1.9            parsnip_0.0.5            infer_0.5.1             
[16] dials_0.0.4              scales_1.1.0             broom_0.5.4             
[19] tidymodels_0.0.3         reshape2_1.4.3           janitor_1.2.1           
[22] data.table_1.12.8        forcats_0.4.0            stringr_1.4.0           
[25] dplyr_0.8.4              purrr_0.3.3              readr_1.3.1             
[28] tidyr_1.0.2              tibble_3.0.1             ggplot2_3.3.0           
[31] tidyverse_1.3.0         

loaded via a namespace (and not attached):
  [1] utf8_1.1.4              tidyselect_1.0.0        lme4_1.1-21            
  [4] htmlwidgets_1.5.1       grid_3.6.2              ranger_0.12.1          
  [7] pROC_1.16.1             munsell_0.5.0           codetools_0.2-16       
 [10] bbotk_0.1               DT_0.12                 future_1.17.0          
 [13] miniUI_0.1.1.1          withr_2.2.0             colorspace_1.4-1       
 [16] knitr_1.28              uuid_0.1-4              rstudioapi_0.10        
 [19] stats4_3.6.2            bayesplot_1.7.1         listenv_0.8.0          
 [22] rstan_2.19.2            lgr_0.3.4               DiceDesign_1.8-1       
 [25] vctrs_0.2.4             generics_0.0.2          ipred_0.9-9            
 [28] xfun_0.12               R6_2.4.1                markdown_1.1           
 [31] mlr3measures_0.1.3-9000 rstanarm_2.19.2         lhs_1.0.1              
 [34] assertthat_0.2.1        promises_1.1.0          nnet_7.3-12            
 [37] gtable_0.3.0            globals_0.12.5          processx_3.4.1         
 [40] timeDate_3043.102       rlang_0.4.5             workflows_0.1.1        
 [43] BBmisc_1.11             splines_3.6.2           checkmate_2.0.0        
 [46] inline_0.3.15           yaml_2.2.1              modelr_0.1.5           
 [49] tidytext_0.2.2          threejs_0.3.3           crosstalk_1.0.0        
 [52] backports_1.1.6         httpuv_1.5.2            rsconnect_0.8.16       
 [55] tokenizers_0.2.1        tools_3.6.2             lava_1.6.6             
 [58] ellipsis_0.3.0          ggridges_0.5.2          Rcpp_1.0.4.6           
 [61] plyr_1.8.5              base64enc_0.1-3         visNetwork_2.0.9       
 [64] ps_1.3.0                prettyunits_1.1.1       rpart_4.1-15           
 [67] zoo_1.8-7               haven_2.2.0             fs_1.3.1               
 [70] furrr_0.1.0             magrittr_1.5            colourpicker_1.0       
 [73] reprex_0.3.0            GPfit_1.0-8             SnowballC_0.6.0        
 [76] packrat_0.5.0           matrixStats_0.55.0      tidyposterior_0.0.2    
 [79] hms_0.5.3               shinyjs_1.1             mime_0.8               
 [82] xtable_1.8-4            XML_3.99-0.3            tidypredict_0.4.3      
 [85] shinystan_2.5.0         readxl_1.3.1            gridExtra_2.3          
 [88] rstantools_2.0.0        compiler_3.6.2          crayon_1.3.4           
 [91] minqa_1.2.4             StanHeaders_2.21.0-1    htmltools_0.4.0        
 [94] later_1.0.0             lubridate_1.7.4         DBI_1.1.0              
 [97] dbplyr_1.4.2            MASS_7.3-51.4           boot_1.3-23            
[100] Matrix_1.2-18           cli_2.0.1               parallel_3.6.2         
[103] gower_0.2.1             igraph_1.2.4.2          pkgconfig_2.0.3        
[106] xml2_1.2.2              foreach_1.4.7           dygraphs_1.1.1.6       
[109] prodlim_2019.11.13      farff_1.1               rvest_0.3.5            
[112] snakecase_0.11.0        janeaustenr_0.1.5       callr_3.4.1            
[115] digest_0.6.25           cellranger_1.1.0        curl_4.3               
[118] shiny_1.4.0             gtools_3.8.1            nloptr_1.2.1           
[121] lifecycle_0.2.0         nlme_3.1-142            jsonlite_1.6.1         
[124] fansi_0.4.1             pillar_1.4.3            lattice_0.20-38        
[127] loo_2.2.0               fastmap_1.0.1           httr_1.4.1             
[130] pkgbuild_1.0.6          survival_3.1-8          glue_1.4.0             
[133] xts_0.12-0              FNN_1.1.3               shinythemes_1.1.2      
[136] iterators_1.0.12        class_7.3-15            stringi_1.4.4          
[139] memoise_1.1.0           future.apply_1.5.0     

非常感谢。

4

1 回答 1

0

我找到了解决方法。

如前所述,问题在于 SMOTE {smotefamily} 的 K 不能大于或等于样本大小。

我深入研究了这一过程,发现 SMOTE {smotefamily} 调用 knearest {smotefamily},后者调用 knnx.index {FNN},而 knnx.index 又调用 get.knnx {FNN}。正是 get.knnx 返回了错误 warning("k should be less than sample size!"),从而终止了 mlr3 中的调参过程。

在 SMOTE {smotefamily} 内部,传给 knearest {smotefamily} 的三个参数是 P_set、P_set 和 K。从 mlr3 重采样的角度来看,数据框 P_set 是训练数据某个交叉验证折叠的子集,过滤后仅包含少数类的记录。错误所指的“样本大小”就是 P_set 的行数。

因此,当 K 通过诸如 some_integer ^ K(例如 2 ^ K)的 trafo 增大时,就更有可能出现 K >= nrow(P_set)。

我们需要确保 K 永远不会大于或等于 P_set 的行数。

这是我提出的解决方案:

  1. 在用 rsmp() 定义 CV 重采样策略之前,先定义变量 cv_folds。
  2. 在定义 trafo 之前,用 rsmp() 并设置 folds = cv_folds 来定义 CV 重采样策略。
  3. 实例化 CV。现在,数据集在每一折中分为训练和测试/验证数据。
  4. 在所有训练数据折叠中找到少数类的最小样本大小,并将其设为 K 的阈值:
# Threshold for K: the smallest class count found in any CV training fold.
# - task$truth(rows = ...) looks the target up by row id, which is what
#   cv$train_set() returns; indexing as.data.frame(task$data()) with those
#   ids treated them as row positions.
# - vapply() over seq_len(cv_folds) yields one integer per fold, so no
#   dplyr::bind_cols / %>% / unique is needed.
smote_k_thresh <- min(vapply(
  seq_len(cv_folds),
  function(fold) {
    min(table(task$truth(rows = cv$train_set(fold))))
  },
  integer(1)
))
  5. 现在定义 trafo 如下:
# Transform SMOTE's K to round(2 ^ K) while keeping it below smote_k_thresh.
#
# x:         named list of sampled parameter values.
# param_set: the search space (not used by this trafo).
# Returns x where every entry whose name ends in ".K" is replaced by
# round(2 ^ K) if that is below smote_k_thresh, and otherwise by a random
# integer between 1 and smote_k_thresh - 1.
param_set$trafo <- function(x, param_set) {
  # Anchored, escaped pattern: match only names ending in a literal ".K".
  # grep() returns integer(0) when nothing matches, so the loop is skipped.
  k_idx <- grep("\\.K$", names(x))
  for (i in k_idx) {
    k_trafo <- round(2 ^ x[[i]])
    if (k_trafo < smote_k_thresh) {
      x[[i]] <- k_trafo
    } else {
      # sample(n, 1) silently returns n itself when n < 1, which would yield
      # K = 0 or a negative K; fail loudly instead.
      if (smote_k_thresh < 2) {
        stop("smote_k_thresh must be at least 2 to draw a valid K",
             call. = FALSE)
      }
      x[[i]] <- sample.int(smote_k_thresh - 1, 1)
    }
  }
  x
}

换句话说,当经过 trafo 变换后的 K 仍小于样本大小时,保留它;否则,将其值设为 1 到 smote_k_thresh - 1 之间的任意整数。

执行

原始代码稍作修改以适应建议的调整:

library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction

# Query OpenML for curated binary classification data sets
# (see https://arxiv.org/abs/1708.03731v2), restricted by number of features
# and number of instances.
ds <- listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)

# Keep only imbalanced data sets (minority share below 20%); categorical
# features are excluded because SMOTE cannot handle them.
ds <- subset(
  ds,
  minority.class.size / number.of.instances < 0.2 &
    number.of.symbolic.features == 1
)
ds

# Download data set 980 and print it.
d <- getOMLDataSet(980)
d

# Convert to a data frame, make sure the target column is a factor, and wrap
# the data in an mlr3 classification task.
data <- as.data.frame(d)
data[[d$target.features]] <- as.factor(data[[d$target.features]])
task <- TaskClassif$new(
  id = d$desc$name,
  backend = data,
  target = d$target.features
)
task

# Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/

# Ratio of the largest to the smallest class size in the task.
# max()/min() always yield a single plain number; selecting with
# `class_counts == max(class_counts)` returned one element per tied class and
# carried the table attributes along.
class_counts <- table(task$truth())
majority_to_minority_ratio <- max(class_counts) / min(class_counts)

# Pipe operator for SMOTE (dup_size set to the rounded class ratio)
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))

# Define and instantiate resampling strategy to be applied within pipeline
# Do that BEFORE defining the trafo: the trafo reads smote_k_thresh, which is
# computed from the training sets of this instantiated resampling.
# cv_folds is kept in a variable because the threshold computation iterates
# over the folds.
cv_folds <- 2
cv <- rsmp("cv", folds = cv_folds)
cv$instantiate(task)

# Calculate max possible value for k-nearest neighbours:
# the smallest class count found in any CV training fold.
# - task$truth(rows = ...) looks the target up by row id, which is what
#   cv$train_set() returns; indexing as.data.frame(task$data()) with those
#   ids treated them as row positions.
# - vapply() over seq_len(cv_folds) yields one integer per fold, so no
#   dplyr::bind_cols / %>% / unique is needed.
smote_k_thresh <- min(vapply(
  seq_len(cv_folds),
  function(fold) {
    min(table(task$truth(rows = cv$train_set(fold))))
  },
  integer(1)
))

# Random forest (ranger) learner that predicts probabilities.
rf <- lrn("classif.ranger", predict_type = "prob")

# Chain SMOTE into the random forest and plot the resulting graph.
graph <- po_smote %>>% po("learner", rf, id = "rf")
graph$plot()

# Wrap the graph as a single learner that predicts probabilities.
rf_smote <- GraphLearner$new(graph, predict_type = "prob")
rf_smote$predict_type <- "prob"

# Hyperparameters of the graph learner as a data.table; show the first four
# columns in the data viewer.
ps_table <- as.data.table(rf_smote$param_set)
View(ps_table[, 1:4])

# Define parameter search space for the SMOTE parameters.
# Only the ids "smote.dup_size" and "smote.K" are tuned. The pattern is
# anchored and the dot is escaped: in the previous patterns ('smote.',
# '.dup_size', '.K') the dot was a regex wildcard and the match could occur
# anywhere in the id.
# NOTE(review): K shares dup_size's upper bound; K is transformed by the trafo
# afterwards, so this bounds the exponent -- confirm this is intended.
smote_ids <- grep("^smote\\.(dup_size|K)$", ps_table$id, value = TRUE)
param_set <- lapply(
  smote_ids,
  function(id) {
    ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
  }
)
param_set <- ParamSet$new(param_set)

# Apply transformation function on SMOTE's K while ensuring it never equals or
# exceeds the sample size.
#
# x:         named list of sampled parameter values.
# param_set: the search space (not used by this trafo).
# Returns x where every entry whose name ends in ".K" is replaced by
# round(5 ^ K) if that is below smote_k_thresh, and otherwise by a random
# integer between 1 and smote_k_thresh - 1.
param_set$trafo <- function(x, param_set) {
  # Anchored, escaped pattern: match only names ending in a literal ".K".
  # grep() returns integer(0) when nothing matches, so the loop is skipped.
  k_idx <- grep("\\.K$", names(x))
  for (i in k_idx) {
    # Try a large base here for the sake of the example
    k_trafo <- round(5 ^ x[[i]])
    if (k_trafo < smote_k_thresh) {
      x[[i]] <- k_trafo
    } else {
      # sample(n, 1) silently returns n itself when n < 1, which would yield
      # K = 0 or a negative K; fail loudly instead.
      if (smote_k_thresh < 2) {
        stop("smote_k_thresh must be at least 2 to draw a valid K",
             call. = FALSE)
      }
      x[[i]] <- sample.int(smote_k_thresh - 1, 1)
    }
  }
  x
}

# Set up tuning instance
# `param_set` is the only unnamed argument, so R matches it positionally to
# the first formal that is not already matched by name.
# The terminator is configured with n_evals = 10.
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 10), 
  store_models = TRUE)
# Random search tuner
tuner <- TunerRandomSearch$new()

# Tune pipe learner to find optimal SMOTE parameter values
tuner$optimize(instance)

# Here are the original K values
instance$archive$data

# And here are their transformations
instance$archive$data$opt_x
于 2020-05-15T11:38:14.320 回答