0

我是 Python 新手,一直在使用 tensorflow 进行文本分类。我想知道这个文本分类模型是否可以用我将来可能获得的每一个新数据进行更新,这样我就不必从头开始训练模型。此外,有时随着时间的推移,课程的数量可能也会更多,因为我主要处理客户数据。是否可以通过使用现有的检查点来使用包含更多类的数据来更新这个现有的文本分类模型?

4

1 回答 1

0

鉴于您要问 2 个不同的问题,我现在分别回答这两个问题:

1) 是的,您可以使用获得的新数据继续训练。这很简单,您只需要像现在一样恢复模型即可使用它。您应该运行优化器操作,而不是运行诸如输出或预测之类的占位符。这转化为以下代码:

model = build_model() # this is the function that build the model graph
saver = tf.train.Saver()
with tf.Session() as session:
    saver.restore(session, "/path/to/model.ckpt")
########### keep training #########
data_x, data_y = load_new_data(new_data_path)
for epoch in range(1, epochs+1):
    all_losses = list()
    num_batches = 0
    for b_x, b_y in batchify(data_x, data_y)
        _, loss = session.run([model.opt, model.loss], feed_dict={model.input:b_x, model.input_y : b_y}
        all_losses.append(loss * len(batch_x))
        num_batches += 1
    print("epoch %d - loss: %2f" % (epoch, sum(losses) / num_batches))

请注意,您现在需要模型定义的操作的名称,以便运行优化器 (model.opt) 和损失操作 (model.loss) 来训练和监控训练期间的损失。

2)如果要更改要使用的标签数量,则要复杂一些。如果您的网络是 1 层前馈,那么没有什么可做的,因为您需要更改矩阵维数,然后您需要从头开始重新训练所有内容。另一方面,如果您有例如多层网络(例如,进行分类的 LSTM + 密集层),那么您可以恢复旧模型的权重并从头开始训练最后一层。为此,我建议您阅读此答案https://stackoverflow.com/a/41642426/4186749

于 2018-08-30T15:12:26.223 回答