我一直在尝试实现我自己的 GAN,但取得了一些有限的成功。我在教程中对他们如何训练它进行了一些调整,并想知道这种更剧烈的变化是否可行。我所做的第一个更改是让鉴别器对 n+1 个类别进行分类,其中 n 可能是,例如 MNIST 中的 10,第 n+1 个类别是假类别。鉴别器是我从头开始为一个非常好的分类器制作的导入架构。那么我的 GAN 就会有传统 NLLLoss 的“对立面”。
这是棘手的非传统计算部分。所以因为我在鉴别器的最后一层有一个 softmax,所以鉴别器的输出总是从 0 到 1。所以我可以为生成器创建一个自定义损失函数,使其成为 NLLLoss 的水平翻转,它试图使确保鉴别器不会将赝品分类为第 n+1 类。这个想法是,只要它们不是第 n + 1 类,我不在乎假货属于哪一类。这种错误分类的行为是我想要为生成器最大化的行为。这是我在desmos上绘制的函数以提供一些可视化:
https://www.desmos.com/calculator/6gdqs28ihk
我的生成器损失函数的实际代码是下面的代码,而我的鉴别器损失函数是传统的 NLLLoss
loss_G = torch.mean(-torch.log(1 - outputG.float()[:,classes]))
请让我知道这是完全错误的还是有更简单的方法。