0

我正在分析一个在图像生成中使用 DCGAN + Reptile的元学习类。

关于这段代码,我有两个问题。

第一个问题:为什么在 DCGAN 训练期间(第 74 行)

training_batch = torch.cat ([real_batch, fake_batch])

是否创建了由真实示例(real_batch)和假示例(fake_batch)组成的training_batch?为什么要通过混合真假图像来进行训练?我见过很多 DCGAN,但从来没有以这种方式进行过训练。

第二个问题:为什么在训练时使用了 normalize_data 函数(第 49 行)和 unnormalize_data 函数(第 55 行)?

def normalize_data(data):
    data *= 2
    data -= 1
    return data


def unnormalize_data(data):
    data += 1
    data /= 2
    return data

该项目使用 Mnist 数据集,如果我想使用像 CIFAR10 这样的颜色数据集,我是否必须修改这些规范化?

4

2 回答 2

1

训练 GAN 涉及为鉴别器提供真实和虚假的示例。通常,您会看到它们是在两个不同的场合给出的。默认情况下torch.cat连接第一个维度 ( dim=0) 上的张量,这是批量维度。因此,它只是将批量大小增加了一倍,其中前半部分是真实图像,后半部分是假图像。

为了计算损失,他们调整了目标,使得前半部分(原始批量大小)被归类为真实的,而后半部分被归类为假的。来自initialize_gan

self.discriminator_targets = torch.tensor([1] * self.batch_size + [-1] * self.batch_size, dtype=torch.float, device=device).view(-1, 1)

图像用 [0, 1] 之间的浮点值表示。规范化会改变它以产生 [-1, 1] 之间的值。GANs 通常在生成器中使用 tanh,因此假图像的值在 [-1, 1] 之间,因此真实图像应该在同一范围内,否则鉴别器区分假图像和真实图像将是微不足道的.

如果要显示这些图像,首先需要对它们进行非规范化,即将它们转换为 [0, 1] 之间的值。

该项目使用 Mnist 数据集,如果我想使用像 CIFAR10 这样的颜色数据集,我是否必须修改这些规范化?

不,您不需要更改它们,因为彩色图像的值也在 [0, 1] 之间,只是有更多的值,代表 3 个通道 (RGB)。

于 2020-05-21T17:36:05.127 回答
0

如果你仔细阅读文档(查看def initialize_gan(self):函数),你会发现

self.meta_g == Generator
self.meta_d == Discriminator

在您引用的行中,fake_batch 被定义为生成器的一部分:

fake_batch = self.meta_g(torch.tensor(np.random.normal(size=(self.batch_size, self.z_shape)), dtype=torch.float, device=device))
training_batch = torch.cat([real_batch, fake_batch])

因此,因为它是一个 GAN,所以你给鉴别器提供了假图像和真实图像,鉴别器必须弄清楚它是哪一个。

关于你的第二个问题,我假设,但我不完全确定这两个函数用于生成假图像?我会仔细检查一下。

这有帮助吗?

于 2020-05-21T16:59:36.813 回答