0

我正在为我的项目使用原始 DCGAN MNIST 代码(keras)。我的任务是生成一个数组,然后我会从中计算出一些 observables。我在每个时期后保存模型,以便我可以找到哪个时期我得到最好的可观察值。我已经使用了 50 个 Epoch,所以我保存了 50 个检查点。现在我想使用一些中间保存的检查点生成数组(通过生成器),那么我应该如何从中加载数据?我用于保存检查点的代码如下:



checkpoint_dir = "./training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

它保存两种类型的文件:ckpt-1.data-00000-of-00001 和 ckpt-1.index。

如何从中生成该数组?(注意:我想要的“数组”类似于在 MNIST 案例中生成的像素数组)

4

1 回答 1

1

您可以使用tf.train.CheckpointManager加载最新的检查点或您喜欢的任何检查点,然后generator根据随机噪声使用您的模型生成一些图像:

import tensorflow as tf

checkpoint_dir = "./training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)
if ckpt_manager.latest_checkpoint:
    checkpoint.restore(ckpt_manager.latest_checkpoint)
    # You can also access previous checkpoints like this: ckpt_manager.checkpoints[3]
    print ('Latest checkpoint restored!!')
    batch_size = 8
    latent_dim = 32
    noise = tf.random.normal([batch_size, latent_dim])
    generated_images = generator(noise, training=False)
    # Plot and/or save your images.
于 2021-11-01T07:54:28.913 回答