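I am experimenting with a custom predictor, so I am running some basic tests to check reproducibility and prediction validity. I built my custom predictor around a very simple linear model and compared the results with the same linear model using the default predict function. The predictions on the whole dataset match between the two predictors, but if I run FeatureImp I get very different results. Any ideas?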
library(iml)
library(dplyr)
set.seed(42)
data("Boston", package = "MASS")
mod <- lm(medv ~ ., data = Boston)
X <- Boston[which(names(Boston) != "medv")]
predictor <- Predictor$new(mod, data = X, y = Boston$medv, type = NULL, class = NULL)
# Create custom predict function
custom_pred_fun <- function(model, newdata) {
  results <- predict(model, data = newdata) %>% as.numeric()
  return(results)
}
# example of prediction output
custom_pred_fun(mod, X) %>% head()
# define custom predictor
custom_predictor <- Predictor$new(
  model = mod,
  data = X,
  y = Boston$medv,
  predict.fun = custom_pred_fun,
  type = NULL,
  class = NULL
)
# check predictions
check_prediction <- custom_predictor$predict(X) %>%
  rename(custom = pred) %>%
  bind_cols(predictor$predict(X)) %>%
  mutate(difference = abs(custom - pred))
range(check_prediction$difference)
# test feature importance
set.seed(42)
imp <- FeatureImp$new(predictor, loss = "mae", n.repetitions = 5, compare = "difference")
imp$results
set.seed(42)
imp2 <- FeatureImp$new(custom_predictor, loss = "mae", n.repetitions = 5, compare = "difference")
imp2$results
The first output is
> imp$results
   feature importance.05    importance importance.95 permutation.error
 1   lstat  2.1493884252  2.1984705726  2.3359516132          5.469333
 2     dis  1.4179569361  1.5727117378  1.6880731883          4.843575
 3     rad  0.9891360131  1.1130780880  1.2701662081          4.383941
 4      rm  0.8710001468  0.9639893200  1.1299396198          4.234852
 5     nox  0.7931062607  0.8750023005  0.9870623982          4.145865
 6     tax  0.8147508309  0.8417232862  1.0065056619          4.112586
 7 ptratio  0.7251740844  0.8358012462  0.8561883645          4.106664
 8      zn  0.1673241230  0.1999924241  0.2861123450          3.470855
 9   black  0.1036532944  0.1777958771  0.2262358726          3.448659
10    crim  0.1551785270  0.1753026006  0.2175427460          3.446165
11    chas  0.0227621927  0.0573480436  0.0644136832          3.328211
12   indus -0.0007655180  0.0052707037  0.0102879482          3.276134
13     age -0.0008699595 -0.0006881205  0.0002567683          3.270175
The second output is
> imp2$results
   feature importance.05 importance importance.95 permutation.error
 1      rm      5.740759   6.312130      6.491682          9.582993
 2     rad      5.866348   6.216290      6.515221          9.487153
 3    chas      6.032097   6.178665      6.312786          9.449528
 4     dis      6.082575   6.172672      6.383786          9.443535
 5 ptratio      6.044840   6.157038      6.326485          9.427901
 6      zn      6.049599   6.135749      6.301742          9.406611
 7   lstat      5.934761   6.133459      6.561748          9.404321
 8    crim      6.061094   6.107696      6.241122          9.378559
 9     age      5.562145   6.057008      6.312441          9.327871
10     nox      5.814618   6.010638      6.298920          9.281501
11     tax      5.940079   5.976696      6.357833          9.247559
12   black      5.861750   5.897589      6.191795          9.168452
13   indus      5.849306   5.872035      6.094271          9.142898
Thanks a lot
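One thing that may be worth checking: the custom wrapper calls `predict(model, data = newdata)`, but `predict.lm()` takes new rows through its `newdata` argument. An argument named `data` is most likely absorbed by `...` and ignored, so the wrapper would return the fitted values for the training data no matter what is passed in. That would explain why predictions on `X` match while permuted features have no effect on the custom predictor's output. The following is a minimal sketch using only base R and `MASS`; the helper names `pred_data_arg` and `pred_newdata_arg` and the choice of `lstat` as the permuted column are illustrative assumptions, and it does not reproduce the internals of `FeatureImp`.

```r
data("Boston", package = "MASS")
mod <- lm(medv ~ ., data = Boston)
X <- Boston[which(names(Boston) != "medv")]

# Wrapper as written in the question: `data` is not a formal argument of
# predict.lm(), so it falls into `...` and the fitted values come back.
pred_data_arg <- function(model, newdata) {
  as.numeric(predict(model, data = newdata))
}

# Hypothetical variant that passes the rows through `newdata`.
pred_newdata_arg <- function(model, newdata) {
  as.numeric(predict(model, newdata = newdata))
}

# Permute one feature, roughly as a permutation-importance routine would.
set.seed(42)
X_perm <- X
X_perm$lstat <- sample(X_perm$lstat)

# On the original rows the two wrappers agree.
same_on_original <- isTRUE(all.equal(pred_data_arg(mod, X),
                                     pred_newdata_arg(mod, X)))

# With a permuted column, the `data =` wrapper returns the same values as
# before, while the `newdata =` wrapper changes its predictions.
ignores_permutation <- isTRUE(all.equal(pred_data_arg(mod, X_perm),
                                        pred_data_arg(mod, X)))
reacts_to_permutation <- !isTRUE(all.equal(pred_newdata_arg(mod, X_perm),
                                           pred_newdata_arg(mod, X)))

print(c(same_on_original = same_on_original,
        ignores_permutation = ignores_permutation,
        reacts_to_permutation = reacts_to_permutation))
```

If all three flags come out `TRUE`, switching the wrapper to `newdata = newdata` and rerunning `FeatureImp$new(custom_predictor, ...)` would be the natural next test to see whether the two importance tables line up.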