错误消息中params
提到的参数是指传递给train()
函数的参数。如果您使用 python API 的 sklearn 类,则某些参数也可用作分类器__init__()
方法中的关键字参数。
例子:
import lightgbm as lgb
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data,
iris.target,
test_size=0.2)
lgb_train = lgb.Dataset(X_train, y_train)
# here are the parameters you need
params = {
'task': 'train',
'boosting_type': 'gbdt',
'objective': 'multiclass',
'num_class': 3,
'max_bin': 4 # <-- max_bin
}
gbm = lgb.train(params,
lgb_train,
num_boost_round=20)
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
y_pred = np.argmax(y_pred, axis=1)
print("Accuracy: ", accuracy_score(y_test, y_pred))
对于详细的示例,我建议查看LGBM 附带的python 示例。