我在尝试将what_if
分析应用于xgboost
模型时遇到了困难。我能够为模型运行what_if
分析,randomForest
但是当我尝试为xgboost
模型运行它时它会中断。
我的问题是,给定titanic
数据集,我该如何制作what_if
情节?我在代码中添加了注释,以显示代码何时对我不利。
我知道我对这new_xgb_observation
部分做了一些不正确的事情,但是what_if
(据我所知)需要一个单一的观察,所以我试图从dtest
矩阵中提取一个单一的观察。
这是对我来说破坏的代码部分:
#### #### #### #### #### #### ####
# new observation - which breaks
new_xgb_observation <- dtest[1, ]
# ceteris paribus - what_if analysis which breaks
what_if(xgb_explain, observation = new_xgb_observation,
selected_variables = c("gender", "age", "fare", "sibsp"))
#### #### #### #### #### #### ####
然后我在它下面展示一个工作randomForest
模型。
数据:
library(DALEX)
library(ceterisParibus)
library(xgboost)
data("titanic")
data <- titanic
# some quick data cleaning
data <- data %>%
select(-c(class, embarked, country)) %>%
mutate(
gender = as.numeric(gender) - 1,
survived = as.numeric(survived) -1
)
# split into training and testing data
smp_size <- floor(0.75 * nrow(data))
train_ind <- sample(seq_len(nrow(data)), size = smp_size)
train <- data[train_ind, ]
test <- data[-train_ind, ]
X_train <- train %>%
select(-c(survived)) %>%
as.matrix()
Y_train <- train %>%
select(c(survived)) %>%
as.matrix()
X_test <- test %>%
select(-c(survived)) %>%
as.matrix()
Y_test <- test %>%
select(c(survived)) %>%
as.matrix()
# train and test as xgb.DMatrix for the XGBoost model
dtrain <- xgb.DMatrix(data = X_train, label = Y_train)
dtest <- xgb.DMatrix(data = X_test, label = Y_test)
# XGBoost parameters
params <- list(
"eta" = 0.2,
"max_depth" = 6,
"objective"="binary:logistic",
"eval_metric"= "auc",
"set.seed" = 176
)
# run the XGBoost model
watchlist <- list("train" = dtrain)
nround = 40
xgb.model <- xgb.train(params, dtrain, nround, watchlist)
# DALEX model explanation
xgb_explain <- explain(xgb.model, data = X_train, label = Y_train)
#### #### #### #### #### #### ####
# new observation - which breaks
new_xgb_observation <- dtest[1, ]
# ceteris paribus - what_if analysis which breaks
what_if(xgb_explain, observation = new_xgb_observation,
selected_variables = c("gender", "age", "fare", "sibsp"))
#### #### #### #### #### #### ####
################## random Forest model #################
Random_Forest_Model <- randomForest::randomForest(factor(survived) ~., data = train, na.action = na.omit, ntree = 50, importance = TRUE)
# same as for the XGBoost model but this time remove na values
train_rf <- na.omit(train)
X_train_rf <- train_rf %>%
select(-c(survived))
Y_train_rf <- train_rf %>%
select(c(survived))
test_rf <- na.omit(test)
X_test_rf <- test_rf %>%
select(-c(survived))
Y_test_rf <- test_rf %>%
select(c(survived))
# DALEX model explanation
rf_explain <- explain(Random_Forest_Model,
data = X_train_rf,
y = Y_train_rf)
# This time this works.
new_obs <- X_test_rf[1, ]
# So does this
wi_rf_model <- what_if(rf_explain, observation = new_obs,
selected_variables = c("gender", "age", "fare", "sibsp"))
# And this is what I ultimately want.
plot(wi_rf_model, split = "variables", color = "variables", quantiles = FALSE)