0

我正在尝试使用自定义预测器,所以我正在做一些基本测试来检查再现性和预测有效性。我尝试了一个非常简单的线性模型来构建我的自定义预测器,并将结果与​​具有默认预测器函数的相同线性模型进行比较。整个数据集的预测值与不同的预测器匹配,但如果我运行FeatImp我会得到非常不同的结果。任何想法?

library(iml)
library(dplyr)

set.seed(42)
data("Boston", package = "MASS")
mod <- lm(medv ~ ., data = Boston)
X <- Boston[which(names(Boston) != "medv")]
predictor <- Predictor$new(mod, data = X, y = Boston$medv, type = NULL, class = NULL)

#  Create custom predict function
custom_pred_fun <- function(model, newdata){
  results <- predict(model, data = newdata) %>% as.numeric()
  return(results)
}

# example of prediction output
custom_pred_fun(mod, X) %>% head()

# define custom predictor
custom_predictor <- Predictor$new(
  model = mod, 
  data = X, 
  y = Boston$medv, 
  predict.fun = custom_pred_fun,
  type = NULL,
  class = NULL
)

# check predictions
check_prediction = custom_predictor$predict(X) %>%
  rename(custom = pred) %>%
  bind_cols(predictor$predict(X)) %>%
  mutate(difference = abs(custom-pred))
range(check_prediction$difference)

# test feature importance
set.seed(42)
imp <- FeatureImp$new(predictor, loss = "mae", n.repetitions = 5, compare = "difference") 
imp$results
set.seed(42)
imp2 <- FeatureImp$new(custom_predictor, loss = "mae", n.repetitions = 5, compare = "difference") 
imp2$results

第一个输出是

> imp$results
   feature importance.05    importance importance.95 permutation.error
1    lstat  2.1493884252  2.1984705726  2.3359516132          5.469333
2      dis  1.4179569361  1.5727117378  1.6880731883          4.843575
3      rad  0.9891360131  1.1130780880  1.2701662081          4.383941
4       rm  0.8710001468  0.9639893200  1.1299396198          4.234852
5      nox  0.7931062607  0.8750023005  0.9870623982          4.145865
6      tax  0.8147508309  0.8417232862  1.0065056619          4.112586
7  ptratio  0.7251740844  0.8358012462  0.8561883645          4.106664
8       zn  0.1673241230  0.1999924241  0.2861123450          3.470855
9    black  0.1036532944  0.1777958771  0.2262358726          3.448659
10    crim  0.1551785270  0.1753026006  0.2175427460          3.446165
11    chas  0.0227621927  0.0573480436  0.0644136832          3.328211
12   indus -0.0007655180  0.0052707037  0.0102879482          3.276134
13     age -0.0008699595 -0.0006881205  0.0002567683          3.270175

第二个输出是

> imp2$results
   feature importance.05 importance importance.95 permutation.error
1       rm      5.740759   6.312130      6.491682          9.582993
2      rad      5.866348   6.216290      6.515221          9.487153
3     chas      6.032097   6.178665      6.312786          9.449528
4      dis      6.082575   6.172672      6.383786          9.443535
5  ptratio      6.044840   6.157038      6.326485          9.427901
6       zn      6.049599   6.135749      6.301742          9.406611
7    lstat      5.934761   6.133459      6.561748          9.404321
8     crim      6.061094   6.107696      6.241122          9.378559
9      age      5.562145   6.057008      6.312441          9.327871
10     nox      5.814618   6.010638      6.298920          9.281501
11     tax      5.940079   5.976696      6.357833          9.247559
12   black      5.861750   5.897589      6.191795          9.168452
13   indus      5.849306   5.872035      6.094271          9.142898

多谢

4

0 回答 0