我正在尝试在 CIFAR-10 数据集上训练 DCGAN。鉴别器和生成器都使用二元交叉熵作为损失函数（生成器是在其后接上一个不可训练的鉴别器来训练的）。如果使用 Adam 优化器，GAN 训练良好；但如果把优化器换成 SGD，训练就会变得混乱：生成器的准确率一开始较高，随着迭代下降到 0 并一直停留在那里；鉴别器的准确率一开始较低，随后上升到 0.5 左右（这是预期的，对吧？）。奇怪的是，生成器的损失随着迭代不断增加。我以为可能是步长（学习率）太大了，于是尝试了更改步长，也尝试了带动量的 SGD。在所有这些情况下，生成器的损失在开始时可能下降、也可能不下降，但随后一定会上升。所以我认为我的模型本身存在某种错误。我知道训练深度模型很困难，而 GAN 更难，但应该存在某些原因或经验法则能解释为什么会发生这种情况。任何意见都不胜感激。我是神经网络和深度学习的新手，因此对 GAN 也很陌生。
以下是我的代码。Cifar10Models.py：
from keras import Sequential
from keras.initializers import TruncatedNormal
from keras.layers import Activation, BatchNormalization, Conv2D, Conv2DTranspose, Dense, Flatten, LeakyReLU, Reshape
from keras.optimizers import SGD
class DcGan:
    """DCGAN generator/discriminator pair for 32x32x3 (CIFAR-10 sized) images.

    The three Keras models are built lazily and cached on the instance:

    * ``generator_model``: 100-d noise vector -> 32x32x3 image, ``tanh`` output.
    * ``discriminator_model``: 32x32x3 image -> one ``sigmoid`` output.
    * ``concatenated_model``: generator followed by the discriminator, compiled
      with the discriminator frozen so it is used to train the generator.
    """

    def __init__(self, print_model_summary: bool = False):
        self.generator_model = None
        self.discriminator_model = None
        self.concatenated_model = None
        # When True, each build_* method prints the Keras summary of its model.
        self.print_model_summary = print_model_summary

    @staticmethod
    def _weight_initializer():
        """Return a fresh TruncatedNormal(mean=0, stddev=0.02) kernel initializer."""
        return TruncatedNormal(mean=0.0, stddev=0.02)

    def build_generator_model(self):
        """Build the generator once and return it (cached on ``self``).

        Dense(4*4*512) on a 100-d input, reshaped to 4x4x512, then three
        stride-2 transposed convolutions (256/128/64 filters; 4 -> 8 -> 16 -> 32
        spatially), each followed by BatchNormalization + ReLU, and finally a
        3-channel Conv2D with ``tanh`` so the output lies in [-1, 1].
        """
        if self.generator_model is not None:
            return self.generator_model
        model = Sequential()
        model.add(Dense(4 * 4 * 512, input_dim=100,
                        kernel_initializer=self._weight_initializer()))
        model.add(BatchNormalization(momentum=0.5))
        model.add(Activation('relu'))
        model.add(Reshape((4, 4, 512)))
        for filters in (256, 128, 64):
            model.add(Conv2DTranspose(filters, 3, strides=2, padding='same',
                                      kernel_initializer=self._weight_initializer()))
            model.add(BatchNormalization(momentum=0.5))
            model.add(Activation('relu'))
        model.add(Conv2D(3, 3, padding='same',
                         kernel_initializer=self._weight_initializer()))
        model.add(Activation('tanh'))
        self.generator_model = model
        if self.print_model_summary:
            model.summary()
        return model

    def build_discriminator_model(self):
        """Build the discriminator once and return it (cached on ``self``).

        Four stride-2 3x3 convolutions (128/256/512/1024 filters) with
        LeakyReLU(0.2), BatchNormalization after every convolution except the
        first, then Flatten -> Dense(1) -> ``sigmoid``.
        """
        if self.discriminator_model is not None:
            return self.discriminator_model
        model = Sequential()
        # The first layer sees raw images and gets no BatchNormalization.
        model.add(Conv2D(128, 3, strides=2, input_shape=(32, 32, 3), padding='same',
                         kernel_initializer=self._weight_initializer()))
        model.add(LeakyReLU(alpha=0.2))
        for filters in (256, 512, 1024):
            model.add(Conv2D(filters, 3, strides=2, padding='same',
                             kernel_initializer=self._weight_initializer()))
            # Normalisation belongs to *this* model; adding it to the generator
            # would stack BN layers after the generator's tanh output and break
            # its [-1, 1] range.
            model.add(BatchNormalization(momentum=0.5))
            model.add(LeakyReLU(alpha=0.2))
        model.add(Flatten())
        model.add(Dense(1, kernel_initializer=self._weight_initializer()))
        # No BatchNormalization on the single output logit: normalising it across
        # the batch would re-centre the scores of an all-fake batch (as used when
        # training the generator) regardless of how fake the images are.
        model.add(Activation('sigmoid'))
        self.discriminator_model = model
        if self.print_model_summary:
            model.summary()
        return model

    def build_concatenated_model(self):
        """Build generator -> discriminator as one Sequential model (cached on ``self``).

        The sub-models are built on demand, so this is safe to call before
        ``build_generator_model`` / ``build_discriminator_model``.
        """
        if self.concatenated_model is not None:
            return self.concatenated_model
        model = Sequential()
        model.add(self.build_generator_model())
        model.add(self.build_discriminator_model())
        self.concatenated_model = model
        if self.print_model_summary:
            model.summary()
        return model

    def build_dc_gan(self, discriminator_optimizer=None, generator_optimizer=None):
        """Build all three models and compile them for adversarial training.

        Args:
            discriminator_optimizer: optimizer used to train the discriminator.
                Defaults to ``SGD(lr=0.0002)``.
            generator_optimizer: optimizer used to train the generator through
                ``concatenated_model``. Defaults to ``SGD(lr=0.0001)``. Any Keras
                optimizer instance (e.g. Adam) may be passed instead.
        """
        self.build_generator_model()
        self.build_discriminator_model()
        self.build_concatenated_model()
        if discriminator_optimizer is None:
            discriminator_optimizer = SGD(lr=0.0002)
        if generator_optimizer is None:
            generator_optimizer = SGD(lr=0.0001)
        # The discriminator is compiled while trainable, and is frozen only while
        # the stacked model is compiled. The intent is that generator updates
        # through concatenated_model leave the discriminator's weights unchanged.
        self.discriminator_model.trainable = True
        self.discriminator_model.compile(loss='binary_crossentropy',
                                         optimizer=discriminator_optimizer,
                                         metrics=['accuracy'])
        self.discriminator_model.trainable = False
        self.concatenated_model.compile(loss='binary_crossentropy',
                                        optimizer=generator_optimizer,
                                        metrics=['accuracy'])
        self.discriminator_model.trainable = True
Cifar10Trainer.py:
# Shree KRISHNAya Namaha
# Based on https://towardsdatascience.com/gan-by-example-using-keras-on-tensorflow-backend-1a6d515a60d0
import os
import datetime
import numpy
import time
from keras.datasets import cifar10
from keras.utils import np_utils
from matplotlib import pyplot as plt
import Cifar10Models
# File name of the CSV training log, created inside the trainer's output folder.
log_file_name: str = 'logs.csv'
class Cifar10Trainer:
    """Train ``Cifar10Models.DcGan`` on the CIFAR-10 training images.

    Progress is appended to a CSV log in the output folder. Models and a 4x4
    grid of generated samples are saved periodically.
    """

    def __init__(self):
        self.x_train, self.y_train = self.get_train_and_test_data()
        self.dc_gan = Cifar10Models.DcGan()
        self.dc_gan.build_dc_gan()

    @staticmethod
    def get_train_and_test_data():
        """Load the CIFAR-10 training split (the test split is discarded).

        Returns:
            ``(x_train, y_train)``: ``x_train`` is float32 rescaled from
            [0, 255] to [-1, 1], matching the generator's tanh output range;
            ``y_train`` holds the labels one-hot encoded over 10 classes.
        """
        (x_train, y_train), _ = cifar10.load_data()
        x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 3)
        # Generator output has tanh activation whose range is [-1, 1].
        x_train = (x_train.astype('float32') * 2 / 255) - 1
        y_train = np_utils.to_categorical(y_train, 10)
        return x_train, y_train

    def train(self, train_steps=10000, batch_size=128, log_interval=10, save_interval=100,
              output_folder_path='./Trained_Models/'):
        """Alternate one discriminator step and one generator step ``train_steps`` times.

        Args:
            train_steps: number of (discriminator, generator) update pairs.
            batch_size: number of real images, and of fake images, per step.
            log_interval: log losses/accuracies every this many steps.
            save_interval: save models and a sample grid every this many steps.
            output_folder_path: folder for the log, models and images; it is
                created if it does not exist.
        """
        # Every output below is written into this folder; without it the first
        # open() in initialize_log raises FileNotFoundError.
        os.makedirs(output_folder_path, exist_ok=True)
        self.initialize_log(output_folder_path)
        self.sample_real_images(output_folder_path)
        # The labels never change, so build them once: real images first (1),
        # then fake images (0), matching the concatenation order below.
        discriminator_labels = numpy.concatenate(
            (numpy.ones([batch_size, 1]), numpy.zeros([batch_size, 1])), axis=0)
        # The generator is trained to make the frozen discriminator output "real".
        generator_labels = numpy.ones([batch_size, 1])
        for step in range(1, train_steps + 1):
            # Train the discriminator on a batch of real and generated images.
            batch_indices = numpy.random.randint(0, self.x_train.shape[0], size=batch_size)
            images_real = self.x_train[batch_indices, :, :, :]
            noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            images_fake = self.dc_gan.generator_model.predict(noise)
            d_loss = self.dc_gan.discriminator_model.train_on_batch(
                numpy.concatenate((images_real, images_fake), axis=0), discriminator_labels)
            # Train the generator through the concatenated model.
            noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            g_loss = self.dc_gan.concatenated_model.train_on_batch(noise, generator_labels)
            if step % log_interval == 0:
                self.log_progress(output_folder_path, step, g_loss, d_loss)
            if step % save_interval == 0:
                self.save_models(output_folder_path, step)
                self.generate_images(output_folder_path, step)

    @staticmethod
    def initialize_log(output_folder_path):
        """Create (or truncate) the CSV log and write its header row."""
        log_line = 'Iteration No, Generator Loss, Generator Accuracy, Discriminator Loss, Discriminator Accuracy, ' \
                   'Time\n'
        with open(os.path.join(output_folder_path, log_file_name), 'w') as log_file:
            log_file.write(log_line)

    @staticmethod
    def log_progress(output_folder_path, iteration_no, g_loss, d_loss):
        """Append one CSV row and echo it to stdout.

        ``g_loss`` and ``d_loss`` are the ``[loss, accuracy]`` values returned by
        ``train_on_batch`` for models compiled with ``metrics=['accuracy']``.
        """
        log_line = '{0:05},{1:2.4f},{2:0.4f},{3:2.4f},{4:0.4f},{5}\n' \
            .format(iteration_no, g_loss[0], g_loss[1], d_loss[0], d_loss[1],
                    datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        with open(os.path.join(output_folder_path, log_file_name), 'a') as log_file:
            log_file.write(log_line)
        print(log_line)

    def save_models(self, output_folder_path, iteration_no):
        """Save the generator, discriminator and concatenated models as .h5 files."""
        self.dc_gan.generator_model.save(
            os.path.join(output_folder_path, 'generator_model_{0}.h5'.format(iteration_no)))
        self.dc_gan.discriminator_model.save(
            os.path.join(output_folder_path, 'discriminator_model_{0}.h5'.format(iteration_no)))
        self.dc_gan.concatenated_model.save(
            os.path.join(output_folder_path, 'concatenated_model_{0}.h5'.format(iteration_no)))

    def sample_real_images(self, output_folder_path):
        """Save a 4x4 grid of randomly chosen training images."""
        filepath = os.path.join(output_folder_path, 'CIFAR10_Sample_Real_Images.png')
        indices = numpy.random.randint(0, self.x_train.shape[0], 16)
        # x_train is scaled to [-1, 1]; map it back to [0, 1] for display, as
        # generate_images does. Otherwise imshow clips every negative value.
        images = (self.x_train[indices, :, :, :] + 1) / 2
        self._save_image_grid(images, filepath)

    def generate_images(self, output_folder_path, iteration_no, noise=None):
        """Save a 4x4 grid of generator samples for ``noise`` (16 random vectors by default)."""
        filepath = os.path.join(output_folder_path, 'CIFAR10_Gen_Image{0}.png'.format(iteration_no))
        if noise is None:
            noise = numpy.random.uniform(-1, 1, size=[16, 100])
        # Generator output has tanh activation whose range is [-1, 1].
        images = (self.dc_gan.generator_model.predict(noise) + 1) / 2
        self._save_image_grid(images, filepath)

    @staticmethod
    def _save_image_grid(images, filepath):
        """Plot the first 16 of ``images`` (32x32x3, expected in [0, 1]) as a 4x4 grid and save it."""
        # imshow only accepts float RGB data in [0, 1]; clip explicitly rather
        # than relying on its implicit clipping (and warning).
        images = numpy.clip(images, 0.0, 1.0)
        plt.figure(figsize=(10, 10))
        for position in range(16):
            plt.subplot(4, 4, position + 1)
            plt.imshow(images[position])
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(filepath)
        plt.close('all')
def main():
    """Train a DCGAN on CIFAR-10 for 10000 steps, then drop the model reference."""
    trainer = Cifar10Trainer()
    trainer.train(train_steps=10000, log_interval=10, save_interval=100)
    # Explicitly release the trainer's reference to the Keras models.
    del trainer.dc_gan
if __name__ == '__main__':
    # Report wall-clock start/end times and the total run duration.
    timestamp_format = '%Y-%m-%d %H:%M:%S'
    start_time = time.time()
    print('Program Started at {0}'.format(time.strftime(timestamp_format, time.localtime(start_time))))
    main()
    end_time = time.time()
    print('Program Ended at {0}'.format(time.strftime(timestamp_format, time.localtime(end_time))))
    elapsed = datetime.timedelta(seconds=end_time - start_time)
    print('Total Execution Time: {0}s'.format(elapsed))
以下是部分训练曲线图：