我使用 GridSearchCV 在训练数据上进行网格搜索,找到了最佳超参数:
model_xgb = XGBClassifier()
n_estimators = [50, 100, 150, 200]
max_depth = [2, 4, 6, 8]
param_grid = dict(max_depth=max_depth, n_estimators=n_estimators)
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model_xgb, param_grid, scoring="neg_log_loss", n_jobs=-1, cv=kfold, verbose=1)
grid_result = grid_search.fit(train_X, y_train)
搜索得到的最佳参数是 {'max_depth': 4, 'n_estimators': 50}。
因此,我使用这些超参数创建了一个新模型:
model_xgb_tn = XGBClassifier(n_estimators=50,max_depth=4,objective='multi:softprob')
当我尝试用数据拟合这个模型时(model_xgb_tn.fit(train_X,y_train)),
出现了 KeyError: 'base_score'。
我根本没有设置 base_score 这个超参数,所以完全不明白为什么会出现这个 KeyError。
下面是完整的错误回溯信息(traceback):
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~\Anaconda3\lib\site-packages\IPython\core\formatters.py in __call__(self, obj, include, exclude)
968
969 if method is not None:
--> 970 return method(include=include, exclude=exclude)
971 return None
972 else:
~\Anaconda3\lib\site-packages\sklearn\base.py in _repr_mimebundle_(self, **kwargs)
461 def _repr_mimebundle_(self, **kwargs):
462 """Mime bundle used by jupyter kernels to display estimator"""
--> 463 output = {"text/plain": repr(self)}
464 if get_config()["display"] == 'diagram':
465 output["text/html"] = estimator_html_repr(self)
~\Anaconda3\lib\site-packages\sklearn\base.py in __repr__(self, N_CHAR_MAX)
277 n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
278
--> 279 repr_ = pp.pformat(self)
280
281 # Use bruteforce ellipsis when there are a lot of non-blank characters
~\Anaconda3\lib\pprint.py in pformat(self, object)
142 def pformat(self, object):
143 sio = _StringIO()
--> 144 self._format(object, sio, 0, 0, {}, 0)
145 return sio.getvalue()
146
~\Anaconda3\lib\pprint.py in _format(self, object, stream, indent, allowance, context, level)
159 self._readable = False
160 return
--> 161 rep = self._repr(object, context, level)
162 max_width = self._width - indent - allowance
163 if len(rep) > max_width:
~\Anaconda3\lib\pprint.py in _repr(self, object, context, level)
391 def _repr(self, object, context, level):
392 repr, readable, recursive = self.format(object, context.copy(),
--> 393 self._depth, level)
394 if not readable:
395 self._readable = False
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in format(self, object, context, maxlevels, level)
168 def format(self, object, context, maxlevels, level):
169 return _safe_repr(object, context, maxlevels, level,
--> 170 changed_only=self._changed_only)
171
172 def _pprint_estimator(self, object, stream, indent, allowance, context,
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
412 recursive = False
413 if changed_only:
--> 414 params = _changed_params(object)
415 else:
416 params = object.get_params(deep=False)
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _changed_params(estimator)
96 init_params = {name: param.default for name, param in init_params.items()}
97 for k, v in params.items():
---> 98 if (repr(v) != repr(init_params[k]) and
99 not (is_scalar_nan(init_params[k]) and is_scalar_nan(v))):
100 filtered_params[k] = v
KeyError: 'base_score'
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~\Anaconda3\lib\site-packages\IPython\core\formatters.py in __call__(self, obj)
700 type_pprinters=self.type_printers,
701 deferred_pprinters=self.deferred_printers)
--> 702 printer.pretty(obj)
703 printer.flush()
704 return stream.getvalue()
~\Anaconda3\lib\site-packages\IPython\lib\pretty.py in pretty(self, obj)
400 if cls is not object \
401 and callable(cls.__dict__.get('__repr__')):
--> 402 return _repr_pprint(obj, self, cycle)
403
404 return _default_pprint(obj, self, cycle)
~\Anaconda3\lib\site-packages\IPython\lib\pretty.py in _repr_pprint(obj, p, cycle)
695 """A pprint that just redirects to the normal repr function."""
696 # Find newlines and replace them with p.break_()
--> 697 output = repr(obj)
698 for idx,output_line in enumerate(output.splitlines()):
699 if idx:
~\Anaconda3\lib\site-packages\sklearn\base.py in __repr__(self, N_CHAR_MAX)
277 n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW)
278
--> 279 repr_ = pp.pformat(self)
280
281 # Use bruteforce ellipsis when there are a lot of non-blank characters
~\Anaconda3\lib\pprint.py in pformat(self, object)
142 def pformat(self, object):
143 sio = _StringIO()
--> 144 self._format(object, sio, 0, 0, {}, 0)
145 return sio.getvalue()
146
~\Anaconda3\lib\pprint.py in _format(self, object, stream, indent, allowance, context, level)
159 self._readable = False
160 return
--> 161 rep = self._repr(object, context, level)
162 max_width = self._width - indent - allowance
163 if len(rep) > max_width:
~\Anaconda3\lib\pprint.py in _repr(self, object, context, level)
391 def _repr(self, object, context, level):
392 repr, readable, recursive = self.format(object, context.copy(),
--> 393 self._depth, level)
394 if not readable:
395 self._readable = False
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in format(self, object, context, maxlevels, level)
168 def format(self, object, context, maxlevels, level):
169 return _safe_repr(object, context, maxlevels, level,
--> 170 changed_only=self._changed_only)
171
172 def _pprint_estimator(self, object, stream, indent, allowance, context,
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _safe_repr(object, context, maxlevels, level, changed_only)
412 recursive = False
413 if changed_only:
--> 414 params = _changed_params(object)
415 else:
416 params = object.get_params(deep=False)
~\Anaconda3\lib\site-packages\sklearn\utils\_pprint.py in _changed_params(estimator)
96 init_params = {name: param.default for name, param in init_params.items()}
97 for k, v in params.items():
---> 98 if (repr(v) != repr(init_params[k]) and
99 not (is_scalar_nan(init_params[k]) and is_scalar_nan(v))):
100 filtered_params[k] = v
KeyError: 'base_score'