9

为了方便起见,当我使用带有 cross_val_score 的嵌套交叉验证时,我想知道 GridSearch 的结果。

使用 cross_val_score 时,您会得到一个分数数组。接收拟合的估计器或为该估计器选择参数的摘要会很有用。

我知道您可以自己执行此操作,但只需手动实现交叉验证,但如果可以与 cross_val_score 结合使用会更方便。

有什么方法可以做到这一点,或者这是一个建议的功能?

4

3 回答 3

6

GridSearchCVscikit-learn 中的类已经在内部进行了交叉验证。您可以将任何CV 迭代cv器作为GridSearchCV.

于 2013-06-24T13:07:45.467 回答
2

你的问题的答案是它是一个建议的功能。不幸的是,您无法使用cross_val_score(截至目前,scikit 0.14)获得适合嵌套交叉验证的模型的最佳参数。

看这个例子:

from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.grid_search import GridSearchCV
from sklearn.cross_validation import cross_val_score

digits = datasets.load_digits()
X = digits.data
y = digits.target

hyperparams = [{'fit_intercept':[True, False]}]
algo = LinearRegression()

grid = GridSearchCV(algo, hyperparams, cv=5, scoring='mean_squared_error')

# Nested cross validation
cross_val_score(grid, X, y)
grid.best_score_

[Out]:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-4c4ac83c58fb> in <module>()
     15 # Nested cross validation
     16 cross_val_score(grid, X, y)
---> 17 grid.best_score_

AttributeError: 'GridSearchCV' object has no attribute 'best_score_'

(另请注意,您从中获得的分数cross_val_score不是 中定义的分数scoring,这里是均方误差。您看到的是最佳估计器的分数函数。这里描述了 v0.14 的错误。)

于 2014-07-03T12:41:37.630 回答
1

sklearn v0.20.0(将于 2018 年底发布)中,受过训练的估计器会在cross_validate需要时由函数公开。

在此处查看新功能的相应拉取请求。像这样的东西会起作用:

from sklearn.metrics.scorer import check_scoring
from sklearn.model_selection import cross_validate
scorer = check_scoring(estimator=gridSearch, scoring=scoring)
cvRet = cross_validate(estimator=gridSearch, X=X, y=y,
                       scoring={'score': scorer}, cv=cvOuter,
                       return_train_score=False,
                       return_estimator=True,
                       n_jobs=nJobs)

scores = cvRet['test_score']  # Equivalent to output of cross_val_score()
estimators = cvRet['estimator']

如果return_estimator=True,估计量可以从返回的字典中检索为cvRet['estimator']。存储的列表cvRet['test_score']相当于 的输出cross_val_score。请参阅此处如何cross_val_score()通过cross_validate().

于 2018-08-26T22:58:21.557 回答