2

我正在尝试使用 SHAP 库获取高斯过程回归 (GPR) 模型的 SHAP 值。但是,所有 SHAP 值都为零。我正在使用官方文档中的示例。我只是将模型更改为 GPR。

import sklearn
from sklearn.model_selection import train_test_split
import numpy as np
import shap
import time
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel, ConstantKernel

shap.initjs()

X,y = shap.datasets.diabetes()
X_train,X_test,y_train,y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# rather than use the whole training set to estimate expected values, we summarize with
# a set of weighted kmeans, each weighted by the number of points they represent.
X_train_summary = shap.kmeans(X_train, 10)


kernel = Matern(length_scale=2, nu=3/2) + WhiteKernel(noise_level=1)   

gp = GaussianProcessRegressor(kernel)
gp.fit(X_train, y_train)

# explain all the predictions in the test set
explainer = shap.KernelExplainer(gp.predict, X_train_summary)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)

运行上面的代码会得到以下图:

在此处输入图像描述

当我使用神经网络或线性回归时,上面的代码可以正常工作。
如果您知道如何解决此问题,请告诉我。

4

1 回答 1

1

您的模型无法预测任何内容:

plt.scatter(y_test, gp.predict(X_test));

在此处输入图像描述

正确训练您的模型,如下所示:

plt.scatter(y_test, gp.predict(X_test));

在此处输入图像描述

你可以去:

explainer = shap.KernelExplainer(gp.predict, X_train_summary)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)

在此处输入图像描述

完全可重现的例子

import sklearn
from sklearn.model_selection import train_test_split
import numpy as np
import shap
import time
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import WhiteKernel, DotProduct

X,y = shap.datasets.diabetes()
X_train,X_test,y_train,y_test = train_test_split(X, y, test_size=0.2, random_state=0)
X_train_summary = shap.kmeans(X_train, 10)
kernel = DotProduct() + WhiteKernel()

gp = GaussianProcessRegressor(kernel)
gp.fit(X_train, y_train)

explainer = shap.KernelExplainer(gp.predict, X_train_summary)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
于 2021-03-03T14:41:47.020 回答