建议的解决方案
重用您共享的存储库中的代码,这里有一些建议的修改,以沿着您的生成器和鉴别器训练分类器(它们的架构和其他损失保持不变):
from keras import backend as K
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
def lenet_classifier_model(nb_classes):
# Snipped by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
# Replace with your favorite classifier...
model = Sequential()
model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(180, activation='relu', init='he_normal'))
model.add(Dropout(0.5))
model.add(Dense(100, activation='relu', init='he_normal'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes, activation='softmax', init='he_normal'))
def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
inputs = Input((IN_CH, img_cols, img_rows))
x_generator = generator(inputs)
merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
discriminator.trainable = False
x_discriminator = discriminator(merged)
classifier.trainable = False
x_classifier = classifier(x_generator)
model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])
return model
def train(BATCH_SIZE):
(X_train, Y_train, LABEL_train) = get_data('train') # replace with your data here
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5
discriminator = discriminator_model()
generator = generator_model()
classifier = lenet_classifier_model(6)
generator.summary()
discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
generator, discriminator, classifier)
d_optim = Adagrad(lr=0.005)
g_optim = Adagrad(lr=0.005)
generator.compile(loss='mse', optimizer="rmsprop")
discriminator_and_classifier_on_generator.compile(
loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
optimizer="rmsprop")
discriminator.trainable = True
discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
classifier.trainable = True
classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")
for epoch in range(100):
print("Epoch is", epoch)
print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
for index in range(int(X_train.shape[0] / BATCH_SIZE)):
image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE] # replace with your data here
generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
if index % 20 == 0:
image = combine_images(generated_images)
image = image * 127.5 + 127.5
image = np.swapaxes(image, 0, 2)
cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
# Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")
# Training D:
real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),
axis=1)
fake_pairs = np.concatenate(
(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
X = np.concatenate((real_pairs, fake_pairs))
y = np.zeros((20, 1, 64, 64)) # [1] * BATCH_SIZE + [0] * BATCH_SIZE
d_loss = discriminator.train_on_batch(X, y)
print("batch %d d_loss : %f" % (index, d_loss))
discriminator.trainable = False
# Training C:
c_loss = classifier.train_on_batch(image_batch, label_batch)
print("batch %d c_loss : %f" % (index, c_loss))
classifier.trainable = False
# Train G:
g_loss = discriminator_and_classifier_on_generator.train_on_batch(
X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :],
[image_batch, np.ones((10, 1, 64, 64)), label_batch])
discriminator.trainable = True
classifier.trainable = True
print("batch %d g_loss : %f" % (index, g_loss[1]))
if index % 20 == 0:
generator.save_weights('generator', True)
discriminator.save_weights('discriminator', True)
理论细节
我认为对于条件 GAN 的工作原理以及鉴别器在此类方案中的作用存在一些误解。
鉴别器的作用
在 GAN 训练 [4] 的 min-max 游戏中,判别D
器与生成器G
(您真正关心的网络)进行对抗,以便在D
的审查下,G
在输出真实结果方面变得更好。
为此,D
经过训练,可以将真实样本与来自 的样本区分开来G
;whileG
被训练为D
通过在目标分布之后生成真实的结果/结果来愚弄。
注意:在条件 GAN 的情况下,即 GAN 将输入样本从一个域A
(例如真实图片)映射到另一个域B
(例如草图),D
通常输入堆叠在一起的样本对,并且必须区分“真实”对(来自的输入样本A
+ 来自 的相应目标样本B
)和“假”对(来自 的输入样本A
+ 来自 的相应输出G
) [1, 2]
训练条件生成器D
(而不是简单地G
单独训练,仅使用 L1/L2 损失,例如 DAE)提高了 的采样能力G
,迫使它输出清晰、真实的结果,而不是试图平均分布。
即使鉴别器可以有多个子网络来覆盖其他任务(见下一段),D
也应该至少保留一个子网络/输出来覆盖其主要任务:将真实样本与生成的样本区分开来。要求D
回归进一步的语义信息(例如类)可能会干扰这个主要目的。
注意:D
输出通常不是简单的标量/布尔值。通常有一个鉴别器(例如 PatchGAN [1, 2])返回一个概率矩阵,评估从其输入生成的补丁的真实性。
条件 GAN
传统的 GAN 以无监督的方式进行训练,以从作为输入的随机噪声向量生成真实数据(例如图像)。[4]
如前所述,条件 GAN 具有进一步的输入条件。沿着/而不是噪声向量,它们将来自域的样本作为输入,A
并从域返回相应的样本B
。A
可以是完全不同的模态,例如B = sketch image
while A = discrete label
; B = volumetric data
而A = RGB image
等 [3]
这样的 GAN 也可以通过多个输入来调节,例如A = real image + discrete label
while B = sketch image
。介绍这种方法的著名工作是InfoGAN [5]。它介绍了如何在多个连续或离散输入(例如A = digit class + writing type
,B = handwritten digit image
G
最大化 cGAN 的互信息
InfoGAN 鉴别器有 2 个头/子网络来覆盖其 2 个任务 [5]:
- 一个负责
D1
人进行传统的真实/生成的区分——G
必须最小化这个结果,即它必须愚弄D1
以使其无法区分真实形式的生成数据;
- 另一个头
D2
(也称为Q
网络)试图回归输入A
信息 -G
必须最大化这个结果,即它必须输出“显示”请求的语义信息的数据(参见G
条件输入与其输出之间的互信息最大化)。
例如,您可以在此处找到 Keras 实现:https ://github.com/eriklindernoren/Keras-GAN/tree/master/infogan 。
一些工作正在使用类似的方案来改进对 GAN 生成内容的控制,方法是使用提供的标签并最大化这些输入和G
输出之间的互信息 [6, 7]。基本思想总是相同的:
- 给定域的一些输入,训练
G
生成域的元素;B
A
- 训练
D
区分“真实”/“假”结果——G
必须尽量减少这一点;
- 训练
Q
(例如分类器;可以与 共享层)以估计来自样本D
的原始A
输入——必须最大化这一点)。B
G
包起来
在您的情况下,您似乎有以下训练数据:
你想训练一个生成器G
,以便给定图像Ia
及其类标签c
,它会输出正确的草图图像Ib'
。
总而言之,你有很多信息,你可以监督你在条件图像和条件标签上的训练......受上述方法 [1, 2, 5, 6, 7] 的启发,这里有一个使用所有这些信息来训练你的条件的可能方法G
:
网络G
:
- 输入:
Ia
+c
- 输出:
Ib'
- 架构:由您决定(例如 U-Net、ResNet、...)
- 损失:
Ib'
&之间的L1/L2损失Ib
,-D
损失,Q
损失
网络D
:
- 输入:
Ia
+ Ib
(真对),Ia
+ Ib'
(假对)
- 输出:“虚假”标量/矩阵
- 架构:由你决定(例如 PatchGAN)
- 损失:“虚假”估计的交叉熵
网络Q
:
- 输入:(
Ib
真实样本,用于训练Q
),Ib'
(假样本,反向传播时G
)
- 输出:(
c'
估计类)
- 架构:由您决定(例如 LeNet、ResNet、VGG、...)
c
损失:和之间的交叉熵c'
训练阶段:
- 训练
D
一批真实对Ia
+Ib
然后训练一批假对Ia
+ Ib'
;
- 训练
Q
一批真实样本Ib
;
- 固定
D
和Q
重量;
- Train
G
,将其生成的输出传递Ib'
给它们D
并Q
通过它们进行反向传播。
注意:这是一个非常粗略的架构描述。我建议您阅读文献([1, 5, 6, 7] 作为一个好的开始)以获得更多细节,也许是更详尽的解决方案。
参考
- 伊索拉、菲利普等人。“使用条件对抗网络进行图像到图像的翻译。” arXiv 预印本(2017 年)。http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
- 朱俊彦等人。“使用循环一致的对抗网络进行未配对的图像到图像转换。” arXiv 预印本 arXiv:1703.10593 (2017)。http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
- 米尔扎、迈赫迪和西蒙·奥辛德罗。“有条件的生成对抗网络。” arXiv 预印本 arXiv:1411.1784 (2014)。https://arxiv.org/pdf/1411.1784
- Goodfellow,伊恩等人。“生成对抗网络。” 神经信息处理系统的进展。2014. http://papers.nips.cc/paper/5423-generation-adversarial-nets.pdf
- 陈,习,等。“Infogan:通过信息最大化生成对抗网络的可解释表示学习。” 神经信息处理系统的进展。2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generation-adversarial-nets.pdf
- Lee、Minhyeok 和 Junhee Seok。“可控生成对抗网络”。arXiv 预印本 arXiv:1708.00598 (2017)。https://arxiv.org/pdf/1708.00598.pdf
- Odena、Augustus、Christopher Olah 和 Jonathon Shlens。“使用辅助分类器甘斯的条件图像合成。” arXiv 预印本 arXiv:1610.09585 (2016)。http://proceedings.mlr.press/v70/odena17a/odena17a.pdf