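I am running Gaussian process regression (GPR) on NIR spectral data in Python. I can get some results with GPR and would like to optimize its parameters. I am trying to use GridSearchCV for this, but I keep getting an error, and I cannot find any examples of people using GridSearchCV with Gaussian processes (from sklearn.gaussian_process). My quick question is whether GridSearchCV can be used with GPR at all. If not, what would you recommend for optimizing the parameters?

Here is my error: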

---------------------------------------------------
# Tuning hyper-parameters for precision

Traceback (most recent call last):

  File "", line 1, in <module>
    runfile('C:/Users/hkim.N04485/Desktop/Python/untitled14.py', wdir='C:/Users/hkim.N04485/Desktop/Python')

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 699, in runfile
    execfile(filename, namespace)

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 74, in execfile
    exec(compile(scripttext, filename, 'exec'), glob, loc)

  File "C:/Users/hkim.N04485/Desktop/Python/untitled14.py", line 39, in <module>
    gp.fit(X1, y1_glucose)

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\grid_search.py", line 804, in fit
    return self._fit(X, y, ParameterGrid(self.param_grid))

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\grid_search.py", line 553, in _fit
    for parameters in parameter_iterable

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\externals\joblib\parallel.py", line 804, in __call__
    while self.dispatch_one_batch(iterator):

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\externals\joblib\parallel.py", line 662, in dispatch_one_batch
    self._dispatch(tasks)

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\externals\joblib\parallel.py", line 570, in _dispatch
    job = ImmediateComputeBatch(batch)

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\externals\joblib\parallel.py", line 183, in __init__
    self.results = batch()

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\externals\joblib\parallel.py", line 72, in __call__
    return [func(*args, **kwargs) for func, args, kwargs in self.items]

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\cross_validation.py", line 1550, in _fit_and_score
    test_score = _score(estimator, X_test, y_test, scorer)

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\cross_validation.py", line 1606, in _score
    score = scorer(estimator, X_test, y_test)

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\metrics\scorer.py", line 90, in __call__
    **self._kwargs)

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\metrics\classification.py", line 1203, in precision_score
    sample_weight=sample_weight)

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\metrics\classification.py", line 956, in precision_recall_fscore_support
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)

  File "C:\Users\hkim.N04485\Anaconda2\lib\site-packages\sklearn\metrics\classification.py", line 82, in _check_targets
    "".format(type_true, type_pred))

ValueError: Can't handle mix of multiclass and continuous

How can I fix this?

Here is my code:

from __future__ import print_function  # print() as a function under Python 2 (Anaconda2)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcess
from sklearn.grid_search import GridSearchCV

# X1, y1_glucose (development set) and X2, y2_glucose (test set) are loaded earlier in the script.

tuned_parameters = [{'corr': ['squared_exponential'], 'theta0': [0.01, 0.2, 0.8, 1.]},
                    {'corr': ['cubic'], 'theta0': [0.01, 0.2, 0.8, 1.]}]

scores = ['precision', 'recall']

xy_line = (0, 1200)

for score in scores:
    print("# Tuning hyper-parameters for %s" % score)
    print()

    # The scoring string becomes 'precision_weighted' or 'recall_weighted'.
    gp = GridSearchCV(GaussianProcess(normalize=False), tuned_parameters, cv=5,
                      scoring='%s_weighted' % score)
    gp.fit(X1, y1_glucose)

    print("Best parameters set found on development set:")
    print()
    print(gp.best_params_)
    print()
    print("Grid scores on development set:")
    print()
    # cv_scores avoids rebinding the outer `scores` list inside the loop.
    for params, mean_score, cv_scores in gp.grid_scores_:
        print("%0.3f (+/-%0.03f) for %r"
              % (mean_score, cv_scores.std() * 2, params))

    y_true, y_pred = y2_glucose, gp.predict(X2)

    # Scatter plot (reference vs. predicted)
    fig, ax = plt.subplots(figsize=(11, 13))
    ax.scatter(y2_glucose, y_pred)
    ax.plot(xy_line, xy_line, 'r--')
    major_ticks = np.arange(-300, 2000, 100)
    minor_ticks = np.arange(0, 1201, 100)
    ax.set_xticks(minor_ticks)
    ax.set_yticks(major_ticks)
    ax.grid()
    plt.title('1')
    ax.set_xlabel('Reference')
    ax.set_ylabel('Predicted')
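
A note on the likely cause, going by the traceback: the failure is raised inside precision_score, which is a classification metric, while the Gaussian process returns continuous predictions. That suggests the grid search itself is fine and only the scoring choice needs to change to a regression metric. The following is a minimal sketch of that idea, not a drop-in fix for the script above. It assumes a newer scikit-learn (0.20 or later), where GaussianProcess and sklearn.grid_search no longer exist and GaussianProcessRegressor with kernel objects is used instead. The synthetic X and y, the kernel choices, and the alpha values are all illustrative assumptions standing in for the NIR spectra and the corr/theta0 grid.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, Matern
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in data (assumption: NOT the real NIR spectra).
rng = np.random.RandomState(0)
X = rng.uniform(0.0, 5.0, size=(60, 3))
y = np.sin(X[:, 0]) + 0.5 * X[:, 1] + 0.05 * rng.randn(60)

# Kernel objects take over the role of 'corr' and 'theta0' in the old API.
# The length scales are only starting points: by default the regressor
# refines kernel hyperparameters itself by maximizing the marginal likelihood.
param_grid = {
    "kernel": [RBF(length_scale=0.1), RBF(length_scale=1.0),
               Matern(length_scale=1.0, nu=1.5)],
    "alpha": [1e-6, 1e-2],  # noise/jitter added to the kernel diagonal
}

search = GridSearchCV(
    GaussianProcessRegressor(normalize_y=True, random_state=0),
    param_grid,
    cv=5,
    scoring="neg_mean_squared_error",  # a regression metric; "r2" also works
)
search.fit(X, y)

print("Best parameters:", search.best_params_)
print("Best CV score (negative MSE):", search.best_score_)
```

Because the scorer is negative mean squared error, values closer to zero are better. On the older scikit-learn shown in the traceback, the analogous change would be to pass a regression scorer in place of 'precision_weighted' and 'recall_weighted'; the exact scorer names available depend on the installed version.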