0

当我运行以下数据时,它显示了不正确的 roc_curve。

准备

使用 r-studio 的任何人都应该可以运行以下代码。数据框包含不同员工的特征:绩效评级、销售数据以及他们是否被提升。

我正在尝试创建一个决策树模型,该模型使用所有其他变量来预测员工是否被提升。这个问题的主要目的是找出我在尝试使用 roc_curve() 函数时做错了什么。

library(tidyverse)
library(tidymodels)
library(peopleanalyticsdata)
    

url <- "http://peopleanalytics-regression-book.org/data/salespeople.csv"
    
   

salespeople <- read.csv(url)
    
    
salespeople <- salespeople %>% mutate(promoted = factor(ifelse(promoted == 1, "yes", "no")))
    

创建测试/训练数据

使用我自己自制的 train_test() 函数只是为了好玩!

    train_test <- function(data, train.size=0.7, na.rm=FALSE) {
      if(na.rm == TRUE) {
        dt <- sample(x=nrow(data), size=nrow(data)* train.size)
        data_nm <- na.omit(data)
        train<-data_nm[dt,]
        test<- data_nm[-dt,]
        set <- list(train, test)
        names(set) <- c("train", "test")
        return(set) 
      } else {
        dt <- sample(x=nrow(data), size=nrow(data)* train.size)
        train<-data[dt,]
        test<- data[-dt,]
        set <- list(train, test)
        names(set) <- c("train", "test")
        return(set)  
      }
    }
    
    tt_list <- train_test(salespeople)
    
    sales_train <- tt_list$train
    
    sales_test <- tt_list$test
    
  '''  

创建决策树模型结构/最终模型/预测数据框

'''    
tree <- decision_tree() %>%
          set_engine("rpart") %>%
          set_mode("classification") 


    model <- tree %>% fit(promoted ~ ., data = sales_train)
    
   

    predictions <- predict(model, 
                           sales_test,
                           type = "prob") %>% 
      bind_cols(sales_test)
    
'''    
   

计算并绘制 ROC 曲线

当我使用 .pred_yes 列作为估计列时,它会计算出与我想要的相反的 ROC 曲线。似乎它已将 .pred_no 确定为“真实”估计列

 '''

roc <- roc_curve(predictions, 
   estimate = .pred_yes, 
                         truth = promoted)
        
       

        autoplot(roc)

    '''

想法

当我将 pred_no 作为估计列提供给 roc_curve() 时,问题似乎消失了

仅供参考:这是我的第一个堆栈溢出帖子,如果您有任何建议使这篇文章更清晰/格式更好,请告诉我!

4

1 回答 1

0

factor(c("yes", "no"))中,“否”是第一个级别,大多数建模包假定的级别是感兴趣的级别。在 tidymodels 中,您可以通过event_level参数调整兴趣级别,如此处所述:

library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

url <- "http://peopleanalytics-regression-book.org/data/salespeople.csv"
salespeople <- read_csv(url) %>% 
    mutate(promoted = factor(ifelse(promoted == 1, "yes", "no")))
#> Rows: 351 Columns: 4
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> dbl (4): promoted, sales, customer_rate, performance
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
sales_split <- initial_split(salespeople)
sales_train <- training(sales_split)
sales_test <- testing(sales_split)

tree <- decision_tree() %>%
    set_engine("rpart") %>%
    set_mode("classification") 


tree_fit <- tree %>% fit(promoted ~ ., data = sales_train)
sales_preds <- augment(tree_fit, sales_test)
sales_preds
#> # A tibble: 88 × 7
#>    promoted sales customer_rate performance .pred_class .pred_no .pred_yes
#>    <fct>    <dbl>         <dbl>       <dbl> <fct>          <dbl>     <dbl>
#>  1 no         364          4.89           1 no             0.973    0.0267
#>  2 no         342          3.74           3 no             0.973    0.0267
#>  3 yes        716          3.16           3 yes            0        1     
#>  4 no         450          3.21           3 no             0.973    0.0267
#>  5 no         372          3.87           3 no             0.973    0.0267
#>  6 no         535          4.47           2 no             0.973    0.0267
#>  7 yes        736          3.94           4 yes            0        1     
#>  8 no         330          2.54           2 no             0.973    0.0267
#>  9 no         478          3.48           2 no             0.973    0.0267
#> 10 yes        728          2.66           3 yes            0        1     
#> # … with 78 more rows

sales_preds %>%
    roc_curve(promoted, .pred_yes, event_level = "second") %>%
    autoplot()

reprex 包(v2.0.1)于 2021 年 9 月 8 日创建

于 2021-09-08T21:58:00.153 回答