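I want to compute SHAP values for a linear model. For the regression, I have to use sample weights.

The problem is that I cannot verify whether the sample weights are actually applied correctly when the SHAP values are computed.

Here is an example.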

# Import libraries 

import shap
import pandas 
import numpy 
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score


# Setting up the data and the model

df.head()

       weights  Funnel  Q23_1  Q23_2  Q23_3  Q23_4  Q23_5  Q23_6  Q23_7  Q23_8  Q23_9  Q23_10  Q23_11  Q23_12  Q23_13  Q23_14  Q23_15
847    0.75149     5.0    2.0    2.0    1.0    3.0    3.0    3.0    2.0    5.0    3.0     1.0     2.0     2.0     2.0     3.0     1.0
995    2.18378     2.0    1.0    1.0    1.0    3.0    3.0    3.0    1.0    4.0    2.0     2.0     1.0     2.0     2.0     2.0     2.0
14403  1.10852     2.0    1.0    1.0    1.0    2.0    2.0    4.0    1.0    5.0    1.0     2.0     2.0     1.0     3.0     3.0     1.0
13311  0.85934     4.0    2.0    2.0    3.0    3.0    2.0    3.0    3.0    4.0    4.0     3.0     2.0     2.0     3.0     3.0     2.0
17019  0.95337     2.0    1.0    1.0    2.0    3.0    2.0    2.0    3.0    2.0    2.0     2.0     3.0     1.0     1.0     1.0     2.0

Y = df.drop(['Funnel', 'weights'], axis=1)
X = df[['Funnel']]

lm = LinearRegression()

First, I fit the regression without weights.

fit = lm.fit(X,Y)

pred = fit.predict(X)

print("R2 - No Weights:", r2_score(Y,pred))

Then I fit the regression with weights.

fit = lm.fit(X,Y, sample_weight=df['weights'])

pred = fit.predict(X)

print("R2 - With weights:", r2_score(Y, pred, sample_weight=df['weights']))

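So far, after testing different combinations (and checking the results against other software such as R and SPSS), I have found that I must pass the weights to both fit() and r2_score() to get correct results (see the example above). For example, if I apply the weights only in fit() and not in r2_score(), the reported R2 value is wrong (i.e., the model is wrong). If I also apply the weights in predict(), the R2 value is wrong as well (i.e., the model is wrong).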

fit = lm.fit(X,Y, sample_weight=df['weights'])

pred = fit.predict(X, sample_weight=df['weights'])

print("R2 - With something in between:", r2_score(Y, pred, sample_weight=df['weights']))
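
For reference, the weighted R2 that r2_score reports when sample_weight is passed can be reproduced by hand. The sketch below is a NumPy-only illustration, not the scikit-learn implementation: it uses the weights, Funnel, and Q23_1 columns from the data sample in this question, and the helper names fit_line and weighted_r2 are hypothetical. It assumes the weighted R2 is 1 minus the weighted residual sum of squares divided by the weighted total sum of squares around the weighted mean.

```python
import numpy as np

# Values copied from the data sample in this question
# (Funnel as the single predictor, Q23_1 as one of the targets).
w = np.array([0.75149, 2.18378, 1.10852, 0.85934, 0.95337,
              0.8899, 0.96819, 1.16261, 1.15915, 1.09833])
x = np.array([5.0, 2.0, 2.0, 4.0, 2.0, 2.0, 2.0, 4.0, 4.0, 5.0])
y = np.array([2.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 1.0, 5.0, 1.0])


def fit_line(x, y, w=None):
    """Hypothetical helper: (intercept, slope) of a least-squares line.

    Sample weights are applied by scaling each row by sqrt(w), which
    minimises sum(w * (y - pred) ** 2).
    """
    if w is None:
        w = np.ones_like(y)
    design = np.column_stack([np.ones_like(x), x])
    sw = np.sqrt(w)
    coef = np.linalg.lstsq(design * sw[:, None], y * sw, rcond=None)[0]
    return coef[0], coef[1]


def weighted_r2(y, pred, w):
    """Hypothetical helper: R2 with sample weights in both sums of squares."""
    y_bar = np.average(y, weights=w)
    ss_res = np.sum(w * (y - pred) ** 2)
    ss_tot = np.sum(w * (y - y_bar) ** 2)
    return 1.0 - ss_res / ss_tot


# Weights in the fit and in the score
b0, b1 = fit_line(x, y, w)
r2_weighted_fit = weighted_r2(y, b0 + b1 * x, w)

# No weights in the fit, weights only in the score
a0, a1 = fit_line(x, y)
r2_plain_fit = weighted_r2(y, a0 + a1 * x, w)

print("weighted fit, weighted score:", r2_weighted_fit)
print("plain fit, weighted score:   ", r2_plain_fit)
```

Note that the prediction step takes no weights in this sketch: the weights only change the fitted coefficients and the way the residuals are averaged.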

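However, since I can only compute SHAP values in Python, I have no way to check the results. The question is: how should I apply the sample weights when computing SHAP values?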

Only in the fit function (?):

fit = lm.fit(X,Y, sample_weight=df['weights'])
explainer = shap.LinearExplainer(fit, X, feature_dependence = 'independent')
shap_values = explainer.shap_values(X)

Or also in the explainer() function (?):

fit = lm.fit(X,Y, sample_weight=df['weights'])
explainer = shap.LinearExplainer(fit, X, feature_dependence = 'independent',
                                 sample_weight=df['weights'])
shap_values = explainer.shap_values(X)

There may be other possibilities as well, but I don't know which one is correct.
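
To make the two options concrete, here is a hand calculation that does not call shap at all. It assumes the usual independent-feature formula for a linear model, phi = coef * (x - E[x]), and shows the two places where weights could enter: the fitted coefficient and the background mean E[x]. Whether shap.LinearExplainer accepts a sample_weight argument is not assumed here. The helper name linear_shap is hypothetical, and the data are the weights, Funnel, and Q23_1 columns from the sample in this question.

```python
import numpy as np

# Values copied from the data sample in this question.
w = np.array([0.75149, 2.18378, 1.10852, 0.85934, 0.95337,
              0.8899, 0.96819, 1.16261, 1.15915, 1.09833])
x = np.array([5.0, 2.0, 2.0, 4.0, 2.0, 2.0, 2.0, 4.0, 4.0, 5.0])
y = np.array([2.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 1.0, 5.0, 1.0])

# Weighted least-squares line: scale rows by sqrt(w) and solve.
design = np.column_stack([np.ones_like(x), x])
sw = np.sqrt(w)
intercept, slope = np.linalg.lstsq(design * sw[:, None], y * sw, rcond=None)[0]
pred = intercept + slope * x


def linear_shap(x, slope, background_mean):
    """Hypothetical helper: SHAP values of one feature in a linear model,
    assuming independent features: coef * (x - E[x])."""
    return slope * (x - background_mean)


# Option 1: weights only in the fit; the background mean is a plain average.
mean_plain = x.mean()
phi_plain = linear_shap(x, slope, mean_plain)
base_plain = intercept + slope * mean_plain

# Option 2: weights in the fit and in the background mean.
mean_weighted = np.average(x, weights=w)
phi_weighted = linear_shap(x, slope, mean_weighted)
base_weighted = intercept + slope * mean_weighted

# In both cases base value + SHAP value reproduces the prediction.
print(np.allclose(base_plain + phi_plain, pred))
print(np.allclose(base_weighted + phi_weighted, pred))

# The two options differ by the same constant for every row.
print(np.unique(np.round(phi_plain - phi_weighted, 10)))
```

Under these assumptions the two options give SHAP values that differ only by a constant shift, slope * (weighted mean - plain mean), and only the second option has a base value equal to the weighted mean of the target.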

Here is a small sample of the data.

print(df.to_dict())
{'weights': {847: 0.75149, 995: 2.18378, 14403: 1.10852, 13311: 0.85934, 17019: 0.95337, 23707: 0.8899, 29562: 0.96819, 30627: 1.16261, 15187: 1.15915, 24179: 1.09833}, 'Funnel': {847: 5.0, 995: 2.0, 14403: 2.0, 13311: 4.0, 17019: 2.0, 23707: 2.0, 29562: 2.0, 30627: 4.0, 15187: 4.0, 24179: 5.0}, 'Q23_1': {847: 2.0, 995: 1.0, 14403: 1.0, 13311: 2.0, 17019: 1.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_2': {847: 2.0, 995: 1.0, 14403: 1.0, 13311: 2.0, 17019: 1.0, 23707: 2.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_3': {847: 1.0, 995: 1.0, 14403: 1.0, 13311: 3.0, 17019: 2.0, 23707: 3.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_4': {847: 3.0, 995: 3.0, 14403: 2.0, 13311: 3.0, 17019: 3.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_5': {847: 3.0, 995: 3.0, 14403: 2.0, 13311: 2.0, 17019: 2.0, 23707: 2.0, 29562: 3.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_6': {847: 3.0, 995: 3.0, 14403: 4.0, 13311: 3.0, 17019: 2.0, 23707: 4.0, 29562: 3.0, 30627: 1.0, 15187: 5.0, 24179: 2.0}, 'Q23_7': {847: 2.0, 995: 1.0, 14403: 1.0, 13311: 3.0, 17019: 3.0, 23707: 3.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_8': {847: 5.0, 995: 4.0, 14403: 5.0, 13311: 4.0, 17019: 2.0, 23707: 4.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_9': {847: 3.0, 995: 2.0, 14403: 1.0, 13311: 4.0, 17019: 2.0, 23707: 2.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_10': {847: 1.0, 995: 2.0, 14403: 2.0, 13311: 3.0, 17019: 2.0, 23707: 2.0, 29562: 3.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_11': {847: 2.0, 995: 1.0, 14403: 2.0, 13311: 2.0, 17019: 3.0, 23707: 3.0, 29562: 2.0, 30627: 1.0, 15187: 2.0, 24179: 1.0}, 'Q23_12': {847: 2.0, 995: 2.0, 14403: 1.0, 13311: 2.0, 17019: 1.0, 23707: 2.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_13': {847: 2.0, 995: 2.0, 14403: 3.0, 13311: 3.0, 17019: 1.0, 23707: 2.0, 29562: 4.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_14': {847: 3.0, 995: 2.0, 14403: 3.0, 
13311: 3.0, 17019: 1.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_15': {847: 1.0, 995: 2.0, 14403: 1.0, 13311: 2.0, 17019: 2.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}}