2

我正在尝试在 CIFAR-10 数据集上训练 DC-GAN。鉴别器和生成器(即接上了不可训练鉴别器的级联模型)都使用二元交叉熵作为损失函数。如果使用 Adam 优化器进行训练,GAN 训练良好;但如果把优化器换成 SGD,训练就会变得混乱。生成器的准确率从某个较高的值开始,随着迭代降到 0 并保持在那里。鉴别器的准确率从某个较低的值开始,最终达到 0.5 左右(这是预期的,对吗?)。奇怪的是,生成器的损失随着迭代而增加。我以为可能是步长太大了,于是尝试更改步长,也尝试了带动量的 SGD。在所有这些情况下,生成器损失在开始时可能会下降也可能不会,但随后肯定会增加。所以,我认为我的模型存在固有的错误。我知道训练深度模型很困难,而训练 GAN 更难,但总该有某种原因或经验解释来说明为什么会发生这种情况。任何建议都不胜感激。我是神经网络和深度学习的新手,因此对 GAN 也很陌生。

这是我的代码:Cifar10Models.py

from keras import Sequential
from keras.initializers import TruncatedNormal
from keras.layers import Activation, BatchNormalization, Conv2D, Conv2DTranspose, Dense, Flatten, LeakyReLU, Reshape
from keras.optimizers import SGD


class DcGan:
    """DC-GAN made of a generator, a discriminator and the two stacked together.

    The three Keras ``Sequential`` models are built lazily and cached on the
    instance; ``build_dc_gan`` builds and compiles everything in one call.
    """

    def __init__(self, print_model_summary: bool = False):
        """Create an empty holder; no model exists until a ``build_*`` method runs.

        :param print_model_summary: when True, ``summary()`` is called on each
            model right after it has been built.
        """
        self.generator_model = None
        self.discriminator_model = None
        self.concatenated_model = None
        self.print_model_summary = print_model_summary

    @staticmethod
    def _weight_init():
        """Return a fresh truncated-normal initializer (mean 0.0, stddev 0.02)."""
        return TruncatedNormal(mean=0.0, stddev=0.02)

    def build_generator_model(self):
        """Build (once) and return the generator.

        Layout: a dense projection of a 100-dimensional input reshaped to
        4x4x512, three stride-2 transposed convolutions (256, 128, 64 filters),
        each followed by batch normalization and ReLU, and a final 3-channel
        convolution with ``tanh`` activation.

        :return: the cached generator ``Sequential`` model.
        """
        if self.generator_model is not None:
            return self.generator_model

        model = Sequential()
        model.add(Dense(4 * 4 * 512, input_dim=100, kernel_initializer=self._weight_init()))
        model.add(BatchNormalization(momentum=0.5))
        model.add(Activation('relu'))
        model.add(Reshape((4, 4, 512)))

        for filters in (256, 128, 64):
            model.add(Conv2DTranspose(filters, 3, strides=2, padding='same',
                                      kernel_initializer=self._weight_init()))
            model.add(BatchNormalization(momentum=0.5))
            model.add(Activation('relu'))

        # No batch normalization after the output layer: tanh must be the last op.
        model.add(Conv2D(3, 3, padding='same', kernel_initializer=self._weight_init()))
        model.add(Activation('tanh'))

        if self.print_model_summary:
            model.summary()

        # Cache only the fully assembled model so a failure cannot leave a partial one behind.
        self.generator_model = model
        return model

    def build_discriminator_model(self):
        """Build (once) and return the discriminator.

        Layout: four stride-2 convolutions (128, 256, 512, 1024 filters) with
        LeakyReLU(0.2), batch normalization on all but the first, then a
        flattened single-unit dense layer with sigmoid activation.

        The batch-normalization layers are added to the discriminator itself.
        They were previously appended to the generator by mistake, which left
        the discriminator without normalization and put extra layers after the
        generator's ``tanh`` output.

        :return: the cached discriminator ``Sequential`` model.
        """
        if self.discriminator_model is not None:
            return self.discriminator_model

        model = Sequential()
        # Input layer: no batch normalization here.
        model.add(Conv2D(128, 3, strides=2, input_shape=(32, 32, 3), padding='same',
                         kernel_initializer=self._weight_init()))
        model.add(LeakyReLU(alpha=0.2))

        for filters in (256, 512, 1024):
            model.add(Conv2D(filters, 3, strides=2, padding='same',
                             kernel_initializer=self._weight_init()))
            model.add(BatchNormalization(momentum=0.5))
            model.add(LeakyReLU(alpha=0.2))

        model.add(Flatten())
        model.add(Dense(1, kernel_initializer=self._weight_init()))
        # NOTE(review): normalizing the single output logit is kept from the original
        # layer order; consider removing it if generator training stalls - confirm.
        model.add(BatchNormalization(momentum=0.5))
        model.add(Activation('sigmoid'))

        if self.print_model_summary:
            model.summary()

        self.discriminator_model = model
        return model

    def build_concatenated_model(self):
        """Build (once) and return the generator stacked on top of the discriminator.

        The generator and discriminator are built on demand, so this method can
        be called before the other ``build_*`` methods.

        :return: the cached stacked ``Sequential`` model.
        """
        if self.concatenated_model is not None:
            return self.concatenated_model

        model = Sequential()
        model.add(self.build_generator_model())
        model.add(self.build_discriminator_model())

        if self.print_model_summary:
            model.summary()

        self.concatenated_model = model
        return model

    def build_dc_gan(self):
        """Build all three models and compile the discriminator and the stacked model.

        Both use binary cross-entropy, SGD and an accuracy metric. The
        discriminator's ``trainable`` flag is True while it is compiled on its
        own and False while the stacked model is compiled, then set back to True.
        """
        self.build_generator_model()
        self.build_discriminator_model()
        self.build_concatenated_model()

        self.discriminator_model.trainable = True
        optimizer = SGD(lr=0.0002)
        self.discriminator_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        # Freeze the discriminator before compiling the stacked (generator-training) model.
        self.discriminator_model.trainable = False
        optimizer = SGD(lr=0.0001)
        self.concatenated_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        self.discriminator_model.trainable = True

Cifar10Trainer.py:

# Shree KRISHNAya Namaha
# Based on https://towardsdatascience.com/gan-by-example-using-keras-on-tensorflow-backend-1a6d515a60d0

import os

import datetime
import numpy
import time
from keras.datasets import cifar10
from keras.utils import np_utils
from matplotlib import pyplot as plt

import Cifar10Models

# File name of the CSV training log that is created inside the output folder.
log_file_name = 'logs.csv'


class Cifar10Trainer:
    """Train the DC-GAN from ``Cifar10Models`` on the CIFAR-10 training images."""

    def __init__(self):
        """Load the training data, then build and compile the DC-GAN models."""
        self.x_train, self.y_train = self.get_train_and_test_data()
        self.dc_gan = Cifar10Models.DcGan()
        self.dc_gan.build_dc_gan()

    @staticmethod
    def get_train_and_test_data():
        """Load the CIFAR-10 training split (the test split is discarded).

        :return: ``(x_train, y_train)``: float32 images rescaled with
            ``x * 2 / 255 - 1`` and labels one-hot encoded over 10 classes.
        """
        (x_train, y_train), _ = cifar10.load_data()
        x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], x_train.shape[2], 3)
        # Generator output has tanh activation whose range is [-1,1]
        x_train = (x_train.astype('float32') * 2 / 255) - 1
        y_train = np_utils.to_categorical(y_train, 10)
        return x_train, y_train

    def train(self, train_steps=10000, batch_size=128, log_interval=10, save_interval=100,
              output_folder_path='./Trained_Models/'):
        """Alternate one discriminator update and one generator update per step.

        :param train_steps: number of training iterations.
        :param batch_size: number of real images (and of fake images) per step.
        :param log_interval: write a log line every this many iterations.
        :param save_interval: save the models and a grid of generated images
            every this many iterations.
        :param output_folder_path: folder that receives the log, models and images.
        """
        self.initialize_log(output_folder_path)
        self.sample_real_images(output_folder_path)
        for i in range(train_steps):
            # Get real (Database) Images
            images_real = self.x_train[numpy.random.randint(0, self.x_train.shape[0], size=batch_size), :, :, :]

            # Generate Fake Images
            noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            images_fake = self.dc_gan.generator_model.predict(noise)

            # Train discriminator on both real and fake images: label 1 = real, 0 = fake
            x = numpy.concatenate((images_real, images_fake), axis=0)
            y = numpy.ones([2 * batch_size, 1])
            y[batch_size:, :] = 0
            d_loss = self.dc_gan.discriminator_model.train_on_batch(x, y)

            # Train generator i.e. concatenated model; fakes are labelled as real
            noise = numpy.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            y = numpy.ones([batch_size, 1])
            g_loss = self.dc_gan.concatenated_model.train_on_batch(noise, y)

            # Print Logs, Save Models, generate sample images
            if (i + 1) % log_interval == 0:
                self.log_progress(output_folder_path, i + 1, g_loss, d_loss)
            if (i + 1) % save_interval == 0:
                self.save_models(output_folder_path, i + 1)
                self.generate_images(output_folder_path, i + 1)

    @staticmethod
    def initialize_log(output_folder_path):
        """Create the output folder if needed and start a new log with the CSV header.

        :param output_folder_path: folder in which the log file is (re)created.
        """
        # The folder may not exist yet on a first run; open() would fail without it.
        os.makedirs(output_folder_path, exist_ok=True)
        log_line = 'Iteration No, Generator Loss, Generator Accuracy, Discriminator Loss, Discriminator Accuracy, ' \
                   'Time\n'
        with open(os.path.join(output_folder_path, log_file_name), 'w') as log_file:
            log_file.write(log_line)

    @staticmethod
    def log_progress(output_folder_path, iteration_no, g_loss, d_loss):
        """Append one CSV row to the log file and echo it to stdout.

        :param output_folder_path: folder containing the log file.
        :param iteration_no: 1-based iteration number to record.
        :param g_loss: sequence whose items 0 and 1 are the generator loss and accuracy.
        :param d_loss: sequence whose items 0 and 1 are the discriminator loss and accuracy.
        """
        log_line = '{0:05},{1:2.4f},{2:0.4f},{3:2.4f},{4:0.4f},{5}\n' \
            .format(iteration_no, g_loss[0], g_loss[1], d_loss[0], d_loss[1],
                    datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        with open(os.path.join(output_folder_path, log_file_name), 'a') as log_file:
            log_file.write(log_line)
        print(log_line)

    def save_models(self, output_folder_path, iteration_no):
        """Save the generator, discriminator and stacked model as ``.h5`` files.

        :param output_folder_path: folder that receives the files.
        :param iteration_no: iteration number embedded in each file name.
        """
        self.dc_gan.generator_model.save(
            os.path.join(output_folder_path, 'generator_model_{0}.h5'.format(iteration_no)))
        self.dc_gan.discriminator_model.save(
            os.path.join(output_folder_path, 'discriminator_model_{0}.h5'.format(iteration_no)))
        self.dc_gan.concatenated_model.save(
            os.path.join(output_folder_path, 'concatenated_model_{0}.h5'.format(iteration_no)))

    @staticmethod
    def _save_image_grid(images, filepath):
        """Plot the first 16 images in a 4x4 grid without axes and save it to ``filepath``.

        :param images: array indexed as ``images[k, :, :, :]``; each item is
            reshaped to 32x32x3 and should already be scaled for display.
        :param filepath: destination path of the figure.
        """
        plt.figure(figsize=(10, 10))
        for index in range(16):
            plt.subplot(4, 4, index + 1)
            image = numpy.reshape(images[index, :, :, :], [32, 32, 3])
            plt.imshow(image)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(filepath)
        plt.close('all')

    def sample_real_images(self, output_folder_path):
        """Save a 4x4 grid of 16 randomly chosen training images.

        :param output_folder_path: folder that receives ``CIFAR10_Sample_Real_Images.png``.
        """
        filepath = os.path.join(output_folder_path, 'CIFAR10_Sample_Real_Images.png')
        chosen = numpy.random.randint(0, self.x_train.shape[0], 16)
        # Training images are stored in [-1,1]; map them back to [0,1] for display,
        # the same way generate_images does, so imshow does not clip negative values.
        images = (self.x_train[chosen, :, :, :] + 1) / 2
        self._save_image_grid(images, filepath)

    def generate_images(self, output_folder_path, iteration_no, noise=None):
        """Save a 4x4 grid of 16 images produced by the generator.

        :param output_folder_path: folder that receives the image file.
        :param iteration_no: iteration number embedded in the file name.
        :param noise: generator input; when None, 16 vectors of length 100 are
            drawn uniformly from [-1, 1).
        """
        filepath = os.path.join(output_folder_path, 'CIFAR10_Gen_Image{0}.png'.format(iteration_no))
        if noise is None:
            noise = numpy.random.uniform(-1, 1, size=[16, 100])
        # Generator output has tanh activation whose range is [-1,1]
        images = (self.dc_gan.generator_model.predict(noise) + 1) / 2
        self._save_image_grid(images, filepath)


def main():
    """Create a trainer, run 10000 training steps, then drop its DC-GAN attribute."""
    trainer = Cifar10Trainer()
    trainer.train(train_steps=10000, log_interval=10, save_interval=100)
    # Remove the attribute that holds the models once training has finished.
    del trainer.dc_gan


if __name__ == '__main__':
    # Report the wall-clock start and end of main() plus the elapsed time.
    timestamp_format = '%Y-%m-%d %H:%M:%S'

    start_time = time.time()
    started_at = time.strftime(timestamp_format, time.localtime(start_time))
    print('Program Started at {0}'.format(started_at))

    main()

    end_time = time.time()
    ended_at = time.strftime(timestamp_format, time.localtime(end_time))
    print('Program Ended at {0}'.format(ended_at))

    elapsed = datetime.timedelta(seconds=end_time - start_time)
    print('Total Execution Time: {0}s'.format(elapsed))

部分图表如下:

  1. 鉴别器优化器:SGD(lr=0.0001, beta1=0.5)
    生成器优化器:Adam(lr=0.0001, beta1=0.5) 在此处输入图像描述

  2. 鉴别器优化器:SGD(lr=0.0001)
    生成器优化器:SGD(lr=0.0001) 在此处输入图像描述

  3. 鉴别器优化器:SGD(lr=0.0001)
    生成器优化器:SGD(lr=0.001) 在此处输入图像描述

  4. 鉴别器优化器:SGD(lr=0.0001)
    生成器优化器:SGD(lr=0.0005) 在此处输入图像描述

4

0 回答 0