7

我正在使用交叉验证训练模型,如下所示:

classifier = lgb.Booster(
    params=params, 
    train_set=lgb_train_set,
)

result = lgb.cv(
    init_model=classifier,
    params=params, 
    train_set=lgb_train_set,
    num_boost_round=1000,
    early_stopping_rounds=20,
    verbose_eval=50,
    shuffle=True
)

我想通过多次运行第二个命令(可能使用新的训练集或使用不同的参数)来继续训练模型,它将继续改进模型。

但是,当我尝试这样做时,很明显该模型每次都是从头开始的。

有没有不同的方法来做我想做的事情?

4

4 回答 4

13

可以使用lightgbm.train 的init_model选项解决,它接受两个对象之一

  1. LightGBM 模型的文件名,或
  2. 一个 lightgbm Booster 对象

代码说明:

import numpy as np 
import lightgbm as lgb

data = np.random.rand(1000, 10) # 1000 entities, each contains 10 features
label = np.random.randint(2, size=1000) # binary target
train_data = lgb.Dataset(data, label=label, free_raw_data=False)
params = {}

#Initialize with 10 iterations
gbm_init = lgb.train(params, train_data, num_boost_round = 10)
print("Initial iter# %d" %gbm_init.current_iteration())

# Example of option #1 (pass a file):
gbm_init.save_model('model.txt')
gbm = lgb.train(params, train_data, num_boost_round = 10,
                init_model='model.txt')
print("Option 1 current iter# %d" %gbm.current_iteration())


# Example of option #2 (pass a lightgbm Booster object):
gbm_2 = lgb.train(params, train_data, num_boost_round = 10,
                init_model = gbm_init)
print("Option 2 current iter# %d" %gbm_2.current_iteration())

https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.train.html

于 2019-10-10T16:59:49.240 回答
4

要进行培训,您必须lgb.train再次进行并确保您包含在参数中init_model='model.txt'。为确认您已正确完成培训期间的信息反馈,应从lgb.cv. 然后像这样保存模型最佳迭代bst.save_model('model.txt', num_iteration=bst.best_iteration)

于 2018-02-07T23:04:36.517 回答
3

init_model本身不起作用。我们必须keep_training_booster为方法设置参数train

lgb_params = {
  'keep_training_booster': True,
  'objective': 'regression',
  'verbosity': 100,
}
lgb.train(lgb_params, init_model= ...)

或作为函数参数:

lgb.train(lgb_params, keep_training_booster=True, init_model= ...)
于 2021-01-31T18:11:33.333 回答
2

似乎 lightgbm 不允许将模型实例作为 init_model 传递,因为它只需要文件名:

init_model (string or None, optional (default=None)) – 用于继续训练的 LightGBM 模型或 Booster 实例的文件名。

关联

于 2018-01-03T09:01:18.773 回答