您可以创建一个自定义的 GridSearchCV,对估计器的指定参数值执行穷举搜索。
您还可以选择任何可用的评分函数,例如 Scikit-learn 中的 R² 分数(`r2_score`)。不过,您可以利用下面的简单公式由 R² 分数计算调整后的 R²,然后在自定义的 GridSearchCV 中实现它:$R^2_{adj} = 1 - (1 - R^2)\,\dfrac{n - 1}{n - p - 1}$,其中 $n$ 为样本量,$p$ 为模型中解释变量的个数。
from collections import OrderedDict
from itertools import product
from sklearn.feature_selection import RFE
from sklearn.linear_model import LinearRegression
from sklearn.datasets import load_iris
from sklearn.metrics import r2_score
from sklearn.model_selection import StratifiedKFold
def customR2Score(y_true, y_pred, n, p):
    """
    Compute the adjusted R^2 score from the plain R^2 score.

    :param y_true: ground-truth target values
    :param y_pred: predicted target values
    :param n: the sample size
    :param p: the total number of explanatory variables in the model
    :return: float, adjusted R^2 score
    """
    unexplained = 1 - r2_score(y_true, y_pred)
    # Scale the unexplained share by the degrees-of-freedom ratio
    # (n - 1) / (n - p - 1), then turn it back into a score.
    penalised = unexplained * (n - 1) / (n - p - 1)
    return 1 - penalised
def CustomGridSearchCV(X, Y, param_grid, n_splits=10, n_repeats=3):
    """
    Exhaustively search RFE parameter combinations, scored by adjusted R^2.

    The search is done manually: every combination in ``param_grid`` is used
    to build an ``RFE`` model, which is evaluated with ``n_repeats`` rounds of
    shuffled ``StratifiedKFold`` cross-validation. The combination with the
    highest mean cross-validation adjusted R^2 wins (ties go to the
    combination evaluated last).

    Notes:
      * The folds are shuffled without a fixed random state, so scores can
        differ between runs.
      * The returned model is the ``RFE`` instance of the winning
        combination as it was left after its last cross-validation fit; it
        is NOT refitted on the full data set.
      * NOTE(review): StratifiedKFold stratifies on ``Y``, so ``Y`` is
        presumably expected to hold discrete, class-like labels; confirm
        against the scikit-learn documentation before using continuous
        targets.

    :param X: array_like, shape (n_samples, n_features), Samples. Must support
        indexing with an array of row indices (e.g. a numpy array).
    :param Y: array_like, shape (n_samples, ), Target values. Same indexing
        requirement as ``X``.
    :param param_grid: Dictionary with parameters names (string) as keys and
        lists of parameter settings to try as values, or a list of such
        dictionaries, in which case the grids spanned by each dictionary
        in the list are explored.
    :param n_splits: Number of folds. Must be at least 2. default=10
    :param n_repeats: Number of times cross-validator needs to be repeated.
        default=3
    :return: an OrderedDict with keys ``best_params``, ``best_train_AdjR2``,
        ``best_cross_AdjR2`` and ``best_model``. If no combination is
        evaluated the initial values ({}, 0, 0, None) are returned.
    """
    best_model = OrderedDict()
    best_model['best_params'] = {}
    best_model['best_train_AdjR2'], best_model['best_cross_AdjR2'] = 0, 0
    best_model['best_model'] = None

    # A single grid is a mapping; anything else is treated as a sequence of
    # grids, as promised by the docstring.
    if hasattr(param_grid, 'items'):
        grids = [param_grid]
    else:
        grids = list(param_grid)

    for grid in grids:
        names = list(grid)
        for values in product(*(grid[name] for name in names)):
            # One concrete combination of parameter settings.
            params = dict(zip(names, values))
            model_ = RFE(**params)

            avg_AdjR2_train = 0.
            avg_AdjR2_cross = 0.
            for _ in range(n_repeats):
                skf = StratifiedKFold(n_splits=n_splits, shuffle=True)
                AdjR2_train = 0.
                AdjR2_cross = 0.
                for train_index, cross_index in skf.split(X, Y):
                    x_train, x_cross = X[train_index], X[cross_index]
                    y_train, y_cross = Y[train_index], Y[cross_index]
                    model_.fit(x_train, y_train)
                    # Adjusted R^2 on the training and on the held-out split;
                    # p is the number of features RFE kept.
                    y_pred_train = model_.predict(x_train)
                    y_pred_cross = model_.predict(x_cross)
                    AdjR2_train += customR2Score(
                        y_train, y_pred_train, len(y_train), model_.n_features_)
                    AdjR2_cross += customR2Score(
                        y_cross, y_pred_cross, len(y_cross), model_.n_features_)
                # Mean over the folds of this repetition.
                avg_AdjR2_train += AdjR2_train / n_splits
                avg_AdjR2_cross += AdjR2_cross / n_splits
            avg_AdjR2_train /= n_repeats
            avg_AdjR2_cross /= n_repeats

            # Keep the combination with the highest cross-validation score.
            # Compare signed values: a large negative adjusted R^2 is a bad
            # model and must not win on magnitude. The first evaluated
            # combination is always accepted so that all-negative scores
            # still yield a result.
            is_first = best_model['best_model'] is None
            if is_first or avg_AdjR2_cross >= best_model['best_cross_AdjR2']:
                best_model['best_params'] = params
                best_model['best_train_AdjR2'] = avg_AdjR2_train
                best_model['best_cross_AdjR2'] = avg_AdjR2_cross
                best_model['best_model'] = model_
    return best_model
# Demo: run the custom search on the iris dataset
iris = load_iris()
X, Y = iris.data, iris.target
regr = LinearRegression()
# Grid: a single estimator, and every possible number of features to keep
param_grid = dict(
    estimator=[regr],  # other estimators can be listed here as well
    n_features_to_select=range(1, X.shape[1] + 1),
)
best_model = CustomGridSearchCV(X, Y, param_grid, n_splits=5, n_repeats=2)
# Show the full result, then the feature ranking and selection mask
print(best_model)
print(best_model['best_model'].ranking_)
print(best_model['best_model'].support_)
测试结果
OrderedDict([
('best_params', {'n_features_to_select': 3, 'estimator':
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)}),
('best_train_AdjR2', 0.9286382985850505), ('best_cross_AdjR2', 0.9188172567358479),
('best_model', RFE(estimator=LinearRegression(copy_X=True, fit_intercept=True,
n_jobs=1, normalize=False), n_features_to_select=3, step=1, verbose=0))])
[1 2 1 1]
[ True False True True]