0

这是github代码

for epoch in range(num_epoch):
for i, (img, _) in enumerate(dataloader):
    num_img = img.size(0)
    # =================train discriminator
    img = img.view(num_img, -1)
    real_img = Variable(img).cuda()
    real_label = Variable(torch.ones(num_img)).cuda()
    fake_label = Variable(torch.zeros(num_img)).cuda()

我不明白训练代码中的 torch.ones 和 torch.zeros 是什么。

任何人都可以解释一下吗?

4

1 回答 1

1

你可能知道:在 GAN 中,生成器试图通过说服假样本是真样本来欺骗鉴别器。训练有素的鉴别器以区分真实示例和虚假示例。另一方面,生成器经过训练以生成看起来非常接近真实示例的(假)示例。


分析您共享的代码/示例(在链接中)。

生成器:是一个简单的前馈神经网络。28 * 28生成器从随机(噪声)分布生成图像。生成器的目标是生成看起来像真实图像的图像。

鉴别器:是一个简单的前馈神经网络。判别器提供给定图像的 sigmoid ([0, 1]) 分数。判别器的目标是给假图像打低分(~0),给真实图像打高分(~1)。本质上,鉴别器想要区分真实图像和假图像。


代码是如何工作的?

首先,为鉴别器提供真实图像的示例,并根据鉴别器的预测分数计算损失。

# compute loss of real_img
real_out = D(real_img)
d_loss_real = criterion(real_out, real_label)
real_scores = real_out  # closer to 1 means better

然后向鉴别器提供生成器生成的假图像。损失是根据鉴别器在假样本上的得分来计算的。

# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()
fake_img = G(z)
fake_out = D(fake_img)
d_loss_fake = criterion(fake_out, fake_label)
fake_scores = fake_out  # closer to 0 means better

本质上,生成器和判别器正在相互竞争以成为实现目标的专家。我们可以这样想:如果我们有一个完美的生成器,那么它会创建与真实示例完全相同的假示例,而判别器将无法区分它们,反之亦然。


您上面提供的代码只是使用torch.zeros()and创建标签torch.ones()。您可以简单地将其视为真实和虚假图像的二进制标签。

于 2018-03-16T03:36:55.023 回答