3

我已经训练了一个 seq2seq tensorflow 模型,用于将句子从英语翻译成西班牙语。我为 615 700 步训练了一个模型,并成功保存了模型检查点。我的英语和西班牙语句子的训练数据量是 200 000。我想重新训练这个模型,从 615 700 步中得到 10K 新数据句子。为此,我正在使用序列到序列 tensoflow 模型。如何从最后一个检查点开始重新训练模型?是我用于翻译的链接。

我的火车文件夹中有 3 种类型的文件:

.index
.meta
.data
and checkpoint file.

我的新训练数据集文件分别是europarl_train.es-en.eneuroparl_train.es-en.es用于英语和西班牙语句子。

我编写了一个代码来加载我的模型 .meta 文件和权重

import data_utils
import seq2seq_model
import translate
import tensorflow as tf

with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('/home/i9/L-T_Model_Training/16_NOV_MODEL/train/translate.ckpt-615700.meta')
    saver.restore(sess,tf.train.latest_checkpoint('/home/i9/L-T_Model_Training/16_NOV_MODEL/train/.'))

我怎样才能开始保留这个数据集?

4

1 回答 1

0

节省

根据TensorFlow 版本 2 文档,您可以使用tf.train.Checkpointtf.train.CheckpointManager类来保存模型。考虑以下示例:

checkpoint_dir = './training_checkpoints'       # custom directory
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(model=model)   # your model variable name
manager = tf.train.CheckpointManager(checkpoint=checkpoint, directory=checkpoint_dir, max_to_keep=3)           # max_to_keep means how much of last checkpoints number you like to keep

现在,如果您想保存模型,请键入:manager.save()

加载

再次定义 checkpoint 和 checkpointManager 并运行以下代码:

if manager.latest_checkpoint:
    checkpoint.restore((manager.latest_checkpoint)).assert_consumed()
    print("Restored from {}".format(manager.latest_checkpoint))

如果您遇到类似 (AssertionError: Unresolved object in checkpoint (root)) 的错误,请替换assert_consumedexpect_partial. (去这里的区别:链接

模型已从检查点加载。现在您可以加载数据并修复形状并继续训练您的模型。

于 2020-12-14T16:22:36.020 回答