I don't know whether there is a way to subset the categories inside a patsy formula, but you can do it in the DataFrame.
For example:
import numpy as np
import pandas as pd
import statsmodels.api as sm
# sample data
size = 100
np.random.seed(1)
countries = ['IT', 'UK', 'US', 'FR', 'ES']
df = pd.DataFrame({
    'outcome': np.random.random(size),
    'Country': np.random.choice(countries, size)
})
df['Country'] = df.Country.astype('category')
print(df.Country)
0 ES
1 IT
2 UK
3 US
4 UK
..
95 FR
96 UK
97 ES
98 UK
99 US
Name: Country, Length: 100, dtype: category
Categories (5, object): ['ES', 'FR', 'IT', 'UK', 'US']
Suppose we want to drop the category 'US':
# create a deep copy excluding 'US'
_df = df[df.Country!='US'].copy(deep=True)
print(_df.Country)
0 ES
1 IT
2 UK
4 UK
5 ES
..
94 UK
95 FR
96 UK
97 ES
98 UK
Name: Country, Length: 83, dtype: category
Categories (5, object): ['ES', 'FR', 'IT', 'UK', 'US']
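Even though no rows in the DataFrame have the category 'US' any more, the category itself is still registered on the column. If we use this DataFrame in a statsmodels model, we get a singular matrix error, so we need to remove the unused category: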
# remove unused category 'US'
_df['Country'] = _df.Country.cat.remove_unused_categories()
print(_df.Country)
0 ES
1 IT
2 UK
4 UK
5 ES
..
94 UK
95 FR
96 UK
97 ES
98 UK
Name: Country, Length: 83, dtype: category
Categories (4, object): ['ES', 'FR', 'IT', 'UK']
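To illustrate why the leftover level matters, here is a minimal sketch. It uses `pd.get_dummies` as a stand-in for the dummy coding that a formula would build (that substitution is an assumption, not a look at patsy's internals), and the `demo` frame and all variable names are hypothetical. The point is that an unused category can still produce a column of zeros, which makes the design matrix rank deficient.

```python
import numpy as np
import pandas as pd

# Hypothetical, deterministic demo data: 20 rows per country.
countries = ['IT', 'UK', 'US', 'FR', 'ES']
demo = pd.DataFrame({'Country': pd.Categorical(np.tile(countries, 20))})

# Filter out 'US' but keep the unused category on the column.
subset = demo[demo.Country != 'US'].copy()

# Treatment-style dummies (first level dropped) plus an intercept column.
dummies = pd.get_dummies(subset.Country, drop_first=True).astype(float)
X = np.column_stack([np.ones(len(subset)), dummies.to_numpy()])

# Columns that are entirely zero, and the rank of the design matrix.
column_totals = dummies.to_numpy().sum(axis=0)
zero_cols = [str(c) for c, total in zip(dummies.columns, column_totals) if total == 0]
rank = np.linalg.matrix_rank(X)

print(zero_cols)       # names of the all-zero dummy columns
print(X.shape, rank)   # rank is lower than the number of columns
```

Once the unused category is removed, the all-zero column disappears and the matrix has full column rank again.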
Now we can fit a model:
mod = sm.Logit.from_formula('outcome ~ Country', data=_df)
fit = mod.fit()
print(fit.summary())
Optimization terminated successfully.
         Current function value: 0.684054
         Iterations 4
                           Logit Regression Results
==============================================================================
Dep. Variable:                outcome   No. Observations:                   83
Model:                          Logit   Df Residuals:                       79
Method:                           MLE   Df Model:                            3
Date:                Sun, 16 May 2021   Pseudo R-squ.:                 0.01179
Time:                        22:43:37   Log-Likelihood:                -56.776
converged:                       True   LL-Null:                       -57.454
Covariance Type:            nonrobust   LLR p-value:                    0.7160
=================================================================================
                    coef    std err          z      P>|z|      [0.025      0.975]
---------------------------------------------------------------------------------
Intercept        -0.1493      0.438     -0.341      0.733      -1.007       0.708
Country[T.FR]     0.4129      0.614      0.673      0.501      -0.790       1.616
Country[T.IT]    -0.1223      0.607     -0.201      0.840      -1.312       1.068
Country[T.UK]     0.1027      0.653      0.157      0.875      -1.178       1.383
=================================================================================
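If you need this more than once, the filter and clean-up steps can be wrapped into one small function. This is only a sketch: `drop_category` is a hypothetical helper name, not part of pandas or statsmodels, and it assumes the column already has the `category` dtype. Note that `remove_unused_categories()` drops every category with no remaining rows, not just the one you filtered out.

```python
import pandas as pd

def drop_category(frame, column, value):
    """Hypothetical helper: return a copy of `frame` without rows where
    `column` equals `value`, with unused categories removed."""
    out = frame[frame[column] != value].copy()
    out[column] = out[column].cat.remove_unused_categories()
    return out

# Small hypothetical example frame.
demo = pd.DataFrame({
    'outcome': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    'Country': pd.Categorical(['IT', 'UK', 'US', 'FR', 'ES', 'US']),
})

trimmed = drop_category(demo, 'Country', 'US')
print(trimmed.Country.cat.categories)
```

The original frame is left untouched, so `demo` still carries the 'US' category while `trimmed` can be passed to `from_formula`.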