1

我使用以下代码在训练模型的循环之外创建了一个检查点管理器:

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(object_1=object_1)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)

然后在训练模型时,我ckpt_save_path = ckpt_manager.save()会在每个 epoch 之后保存变量。

鉴于我想实现一种提前停止方法,我需要在特定时期之后恢复所有变量并使用这些变量进行预测。如果我使用上面的代码保存变量(希望保存过程是正确的?),那么在假设 epoch e 之后如何恢复变量。我知道我可以先创建相同的变量和对象,然后使用下面的代码来恢复最新的检查点,但不知道如何恢复特定的检查点(如 epoch number e 之后的变量)而不是最新的。

ckpt.restore(ckpt_manager.latest_checkpoint).assert_consumed()

谢谢,

4

1 回答 1

1

是的,您需要生成带有纪元号的文件名文本字符串。

c_manager = tf.train.CheckpointManager(checkpoint, ...)

if EPOCH == '':
    if c_manager.latest_checkpoint:
        tf.print("-----------Restoring from {}-----------".format(
            c_manager.latest_checkpoint))
        checkpoint.restore(c_manager.latest_checkpoint)
        EPOCH = c_manager.latest_checkpoint.split(sep='ckpt-')[-1]
    else:
        tf.print("-----------Initializing from scratch-----------")
else:    
    checkpoint_fname = CHECKPOINT_SAVE_DIR + 'ckpt-' + str(EPOCH)
    tf.print("-----------Restoring from {}-----------".format(checkpoint_fname))
    checkpoint.restore(checkpoint_fname)
于 2020-07-16T11:39:12.640 回答