I want to compute SHAP values for a linear model. For the regression, I have to use sample weights. The problem is that I cannot assess whether the sample weights are actually being applied when the SHAP values are computed.
Here is an example.
# Import libraries
import shap
import pandas
import numpy
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
# Setting up the data and the model
df.head()
       weights  Funnel  Q23_1  Q23_2  Q23_3  Q23_4  Q23_5  Q23_6  Q23_7  Q23_8  Q23_9  Q23_10  Q23_11  Q23_12  Q23_13  Q23_14  Q23_15
847    0.75149     5.0    2.0    2.0    1.0    3.0    3.0    3.0    2.0    5.0    3.0     1.0     2.0     2.0     2.0     3.0     1.0
995    2.18378     2.0    1.0    1.0    1.0    3.0    3.0    3.0    1.0    4.0    2.0     2.0     1.0     2.0     2.0     2.0     2.0
14403  1.10852     2.0    1.0    1.0    1.0    2.0    2.0    4.0    1.0    5.0    1.0     2.0     2.0     1.0     3.0     3.0     1.0
13311  0.85934     4.0    2.0    2.0    3.0    3.0    2.0    3.0    3.0    4.0    4.0     3.0     2.0     2.0     3.0     3.0     2.0
17019  0.95337     2.0    1.0    1.0    2.0    3.0    2.0    2.0    3.0    2.0    2.0     2.0     3.0     1.0     1.0     1.0     2.0
Y = df.drop(['Funnel', 'weights'], axis=1)
X = df[['Funnel']]
lm = LinearRegression()
First, I compute the regression without weights.
fit = lm.fit(X, Y)
pred = fit.predict(X)
print("R2 - No weights:", r2_score(Y, pred))
Then I compute the regression with the weights.
fit = lm.fit(X, Y, sample_weight=df['weights'])
pred = fit.predict(X)
print("R2 - With weights:", r2_score(Y, pred, sample_weight=df['weights']))
What I have found so far (I tested different combinations and validated the results against other packages, e.g. R and SPSS) is that I have to apply the weights to both the fit() function and the r2_score() function to get the correct results (see the example above). For instance, if I apply the weights only to fit() but not to r2_score(), the reported R2 value is wrong (i.e., the model is wrong). And if I additionally apply the weights to predict(), the R2 value is also wrong (i.e., the model is wrong):
# Incorrect variant: weights applied everywhere, including predict()
fit = lm.fit(X, Y, sample_weight=df['weights'])
pred = fit.predict(X, sample_weight=df['weights'])  # note: scikit-learn's predict() does not accept a sample_weight argument
print("R2 - With something in between:", r2_score(Y, pred, sample_weight=df['weights']))
However, since I can only compute SHAP values in Python, I have no way to cross-validate the results there. So the question is: how should I apply the sample weights when computing the SHAP values?
Only in the fit() function (?):
fit = lm.fit(X, Y, sample_weight=df['weights'])
explainer = shap.LinearExplainer(fit, X, feature_dependence='independent')
shap_values = explainer.shap_values(X)
Or also in the explainer() function (?):
fit = lm.fit(X, Y, sample_weight=df['weights'])
explainer = shap.LinearExplainer(fit, X, feature_dependence='independent',
                                 sample_weight=df['weights'])
shap_values = explainer.shap_values(X)
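Or, as a third possibility, by encoding the weights in the background data handed to the explainer, i.e. passing weighted moments of X instead of the raw matrix (a sketch only; it assumes shap.LinearExplainer accepts a (mean, covariance) tuple as background data, which I believe this version supports):
import numpy

w = df['weights'].to_numpy()
Xv = X.to_numpy().astype(float)

mean_w = numpy.average(Xv, axis=0, weights=w)  # weighted feature means
Xc = Xv - mean_w
cov_w = (w[:, None] * Xc).T @ Xc / w.sum()     # weighted covariance matrix

fit = lm.fit(X, Y, sample_weight=df['weights'])
explainer = shap.LinearExplainer(fit, (mean_w, cov_w), feature_dependence='independent')
shap_values = explainer.shap_values(X)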
There may be other possibilities... but I do not know which one is correct.
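One way I thought of to check the variants: for a linear model with independent features, the SHAP value of feature j should (to my understanding) reduce to coef_j * (x_j - mean_j), where mean_j is the background mean. So the explainer output can be compared against that closed form with an unweighted versus a weighted mean (a sketch; it uses a single-output fit to keep the shapes simple, since the real Y has several columns):
import numpy
from sklearn.linear_model import LinearRegression

# Single-output fit for illustration only
y1 = Y['Q23_1']
fit1 = LinearRegression().fit(X, y1, sample_weight=df['weights'])

Xv = X.to_numpy().astype(float)
w = df['weights'].to_numpy()

# Closed form for independent features: phi_ij = coef_j * (x_ij - background_mean_j)
phi_unweighted = fit1.coef_ * (Xv - Xv.mean(axis=0))
phi_weighted = fit1.coef_ * (Xv - numpy.average(Xv, axis=0, weights=w))
# Compare these against explainer.shap_values(X) from each variant above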
Here is a small data sample.
print(df.to_dict())
{'weights': {847: 0.75149, 995: 2.18378, 14403: 1.10852, 13311: 0.85934, 17019: 0.95337, 23707: 0.8899, 29562: 0.96819, 30627: 1.16261, 15187: 1.15915, 24179: 1.09833}, 'Funnel': {847: 5.0, 995: 2.0, 14403: 2.0, 13311: 4.0, 17019: 2.0, 23707: 2.0, 29562: 2.0, 30627: 4.0, 15187: 4.0, 24179: 5.0}, 'Q23_1': {847: 2.0, 995: 1.0, 14403: 1.0, 13311: 2.0, 17019: 1.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_2': {847: 2.0, 995: 1.0, 14403: 1.0, 13311: 2.0, 17019: 1.0, 23707: 2.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_3': {847: 1.0, 995: 1.0, 14403: 1.0, 13311: 3.0, 17019: 2.0, 23707: 3.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_4': {847: 3.0, 995: 3.0, 14403: 2.0, 13311: 3.0, 17019: 3.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_5': {847: 3.0, 995: 3.0, 14403: 2.0, 13311: 2.0, 17019: 2.0, 23707: 2.0, 29562: 3.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_6': {847: 3.0, 995: 3.0, 14403: 4.0, 13311: 3.0, 17019: 2.0, 23707: 4.0, 29562: 3.0, 30627: 1.0, 15187: 5.0, 24179: 2.0}, 'Q23_7': {847: 2.0, 995: 1.0, 14403: 1.0, 13311: 3.0, 17019: 3.0, 23707: 3.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_8': {847: 5.0, 995: 4.0, 14403: 5.0, 13311: 4.0, 17019: 2.0, 23707: 4.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_9': {847: 3.0, 995: 2.0, 14403: 1.0, 13311: 4.0, 17019: 2.0, 23707: 2.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_10': {847: 1.0, 995: 2.0, 14403: 2.0, 13311: 3.0, 17019: 2.0, 23707: 2.0, 29562: 3.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_11': {847: 2.0, 995: 1.0, 14403: 2.0, 13311: 2.0, 17019: 3.0, 23707: 3.0, 29562: 2.0, 30627: 1.0, 15187: 2.0, 24179: 1.0}, 'Q23_12': {847: 2.0, 995: 2.0, 14403: 1.0, 13311: 2.0, 17019: 1.0, 23707: 2.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_13': {847: 2.0, 995: 2.0, 14403: 3.0, 13311: 3.0, 17019: 1.0, 23707: 2.0, 29562: 4.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_14': {847: 3.0, 995: 2.0, 14403: 3.0, 13311: 3.0, 17019: 1.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_15': {847: 1.0, 995: 2.0, 14403: 1.0, 13311: 2.0, 17019: 2.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}}
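For anyone reproducing this, the DataFrame can be rebuilt from the printed dict (only the first two rows of two columns are repeated here; paste the full dict from above):
import pandas

d = {'weights': {847: 0.75149, 995: 2.18378},
     'Funnel': {847: 5.0, 995: 2.0}}
df = pandas.DataFrame(d)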