1

我在尝试将 what_if 分析应用于 xgboost 模型时遇到了困难。我能够为 randomForest 模型运行 what_if 分析,但是当我尝试为 xgboost 模型运行它时就会报错。

我的问题是:给定 titanic 数据集,我该如何绘制 what_if 图?我在代码中添加了注释,标明代码在哪一步出错。

我知道我在 new_xgb_observation 这部分的做法不正确,但是(据我所知)what_if 需要单个观测值,所以我试图从 dtest 矩阵中提取单个观测值。

这是在我这里出错的代码部分:

#### #### #### #### #### #### ####
# Excerpt of the failing step; the same lines appear in the full script.
# New observation: the first row sliced out of `dtest`, which the full script
# builds with xgb.DMatrix(). This is the step reported as breaking.
# NOTE(review): what_if() presumably needs a one-row data.frame or matrix with
# column names rather than an xgb.DMatrix -- confirm against the
# ceterisParibus documentation.
new_xgb_observation <- dtest[1, ]

# Ceteris paribus (what-if) analysis of the xgboost explainer, restricted to
# four variables. Also reported as breaking.
what_if(xgb_explain, observation = new_xgb_observation,
        selected_variables = c("gender", "age", "fare", "sibsp"))
#### #### #### #### #### #### ####

然后我在它下面展示一个可以正常运行的 randomForest 模型。

数据:

# dplyr supplies %>%, select() and mutate(). It is attached first so that
# DALEX::explain() is the `explain` found on the search path rather than
# dplyr::explain().
library(dplyr)
library(DALEX)
library(ceterisParibus)
library(xgboost)

# Load the titanic data set from the attached packages.
data("titanic")

# Quick clean-up in one nested call: drop the class, embarked and country
# columns, then convert gender and survived to numeric and shift them down
# by one.
data <- mutate(
  select(titanic, -c(class, embarked, country)),
  gender = as.numeric(gender) - 1,
  survived = as.numeric(survived) - 1
)

# Split into training (75% of rows) and testing (the remainder) data.
# The seed is fixed so the random split, and everything fitted on it, can be
# reproduced from run to run.
set.seed(176)
smp_size <- floor(0.75 * nrow(data))
train_ind <- sample(seq_len(nrow(data)), size = smp_size)

train <- data[train_ind, ]
test <- data[-train_ind, ]

# Predictor matrices: every column except `survived`.
X_train <- as.matrix(select(train, -survived))
X_test <- as.matrix(select(test, -survived))

# Outcome matrices: the single `survived` column, kept as a one-column matrix.
Y_train <- as.matrix(select(train, survived))
Y_test <- as.matrix(select(test, survived))

# Wrap the train and test matrices as xgb.DMatrix objects for XGBoost.
dtrain <- xgb.DMatrix(data = X_train, label = Y_train)
dtest <- xgb.DMatrix(data = X_test, label = Y_test)

# XGBoost parameters: logistic objective for a binary outcome, evaluated
# with AUC.
# "set.seed" is not an XGBoost parameter, so listing it here seeded nothing.
# The seed is set with R's set.seed() just before training instead.
params <- list(
  eta = 0.2,
  max_depth = 6,
  objective = "binary:logistic",
  eval_metric = "auc"
)

# Run the XGBoost model for 40 rounds, reporting the evaluation metric on the
# training set.
watchlist <- list(train = dtrain)
nround <- 40
set.seed(176)
xgb.model <- xgb.train(params, dtrain, nround, watchlist)

# DALEX model explanation for the xgboost model.
# - `y` carries the observed outcome; `label` is only a display name for the
#   model, so the response must not be passed there.
# - Rows with missing predictors are dropped from the explainer data, as the
#   random forest section does with na.omit().
#   NOTE(review): what_if() is expected to build its grid from quantiles of
#   these columns, which fails on NA -- confirm against ceterisParibus.
# - The predict function coerces whatever what_if() passes in to a numeric
#   matrix before calling predict() on the booster.
complete_train <- complete.cases(X_train)
xgb_explain <- DALEX::explain(
  xgb.model,
  data = X_train[complete_train, , drop = FALSE],
  y = as.numeric(Y_train[complete_train, ]),
  predict_function = function(model, newdata) {
    predict(model, as.matrix(newdata))
  },
  label = "xgboost"
)

#### #### #### #### #### #### ####
# New observation: one row of the plain predictor matrix, not of the
# xgb.DMatrix. drop = FALSE keeps it a one-row matrix with column names, so
# it can be indexed by variable name. The first complete test row is used so
# every selected variable has an observed value.
X_test_complete <- X_test[complete.cases(X_test), , drop = FALSE]
new_xgb_observation <- X_test_complete[1, , drop = FALSE]

# Ceteris paribus (what-if) analysis for the selected variables, plotted the
# same way as the random forest result.
wi_xgb_model <- what_if(xgb_explain, observation = new_xgb_observation,
                        selected_variables = c("gender", "age", "fare", "sibsp"))
plot(wi_xgb_model, split = "variables", color = "variables", quantiles = FALSE)
#### #### #### #### #### #### ####

################## random Forest model #################

# Fit a 50-tree random forest on `survived` (as a factor) against all other
# columns; rows with missing values are omitted by the na.action.
Random_Forest_Model <- randomForest::randomForest(
  factor(survived) ~ .,
  data = train,
  na.action = na.omit,
  ntree = 50,
  importance = TRUE
)

# Same predictor/outcome split as for the XGBoost model, but on rows without
# missing values and kept as data frames.
train_rf <- na.omit(train)
X_train_rf <- select(train_rf, -survived)
Y_train_rf <- select(train_rf, survived)

test_rf <- na.omit(test)
X_test_rf <- select(test_rf, -survived)
Y_test_rf <- select(test_rf, survived)

# DALEX model explanation for the random forest.
# `y` is passed as a plain numeric vector (the `survived` column) instead of a
# one-column data frame, and explain() is namespace-qualified so it cannot be
# masked by another attached package's explain().
rf_explain <- DALEX::explain(Random_Forest_Model,
                             data = X_train_rf,
                             y = Y_train_rf$survived)

# New observation: the first NA-free test row, kept as a one-row data frame.
new_obs <- X_test_rf[1, ]

# Ceteris paribus (what-if) analysis for the selected variables.
wi_rf_model <- what_if(rf_explain, observation = new_obs,
                       selected_variables = c("gender", "age", "fare", "sibsp"))

# The plot that is ultimately wanted.
plot(wi_rf_model, split = "variables", color = "variables", quantiles = FALSE)
4

0 回答 0