为了方便起见,当我使用带有 cross_val_score 的嵌套交叉验证时,我想知道 GridSearch 的结果。
使用 cross_val_score 时,您会得到一个分数数组。接收拟合的估计器或为该估计器选择参数的摘要会很有用。
我知道您可以自己执行此操作,但只需手动实现交叉验证,但如果可以与 cross_val_score 结合使用会更方便。
有什么方法可以做到这一点,或者这是一个建议的功能?
为了方便起见,当我使用带有 cross_val_score 的嵌套交叉验证时,我想知道 GridSearch 的结果。
使用 cross_val_score 时,您会得到一个分数数组。接收拟合的估计器或为该估计器选择参数的摘要会很有用。
我知道您可以自己执行此操作,但只需手动实现交叉验证,但如果可以与 cross_val_score 结合使用会更方便。
有什么方法可以做到这一点,或者这是一个建议的功能?
GridSearchCV
scikit-learn 中的类已经在内部进行了交叉验证。您可以将任何CV 迭代cv
器作为GridSearchCV
.
你的问题的答案是它是一个建议的功能。不幸的是,您无法使用cross_val_score
(截至目前,scikit 0.14)获得适合嵌套交叉验证的模型的最佳参数。
看这个例子:
from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.grid_search import GridSearchCV
from sklearn.cross_validation import cross_val_score
digits = datasets.load_digits()
X = digits.data
y = digits.target
hyperparams = [{'fit_intercept':[True, False]}]
algo = LinearRegression()
grid = GridSearchCV(algo, hyperparams, cv=5, scoring='mean_squared_error')
# Nested cross validation
cross_val_score(grid, X, y)
grid.best_score_
[Out]:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-4-4c4ac83c58fb> in <module>()
15 # Nested cross validation
16 cross_val_score(grid, X, y)
---> 17 grid.best_score_
AttributeError: 'GridSearchCV' object has no attribute 'best_score_'
(另请注意,您从中获得的分数cross_val_score
不是 中定义的分数scoring
,这里是均方误差。您看到的是最佳估计器的分数函数。这里描述了 v0.14 的错误。)
在sklearn v0.20.0
(将于 2018 年底发布)中,受过训练的估计器会在cross_validate
需要时由函数公开。
在此处查看新功能的相应拉取请求。像这样的东西会起作用:
from sklearn.metrics.scorer import check_scoring
from sklearn.model_selection import cross_validate
scorer = check_scoring(estimator=gridSearch, scoring=scoring)
cvRet = cross_validate(estimator=gridSearch, X=X, y=y,
scoring={'score': scorer}, cv=cvOuter,
return_train_score=False,
return_estimator=True,
n_jobs=nJobs)
scores = cvRet['test_score'] # Equivalent to output of cross_val_score()
estimators = cvRet['estimator']
如果return_estimator=True
,估计量可以从返回的字典中检索为cvRet['estimator']
。存储的列表cvRet['test_score']
相当于 的输出cross_val_score
。请参阅此处如何cross_val_score()
通过cross_validate()
.