0

我正在尝试基于 Wassertein-GP 方法为表格 GAN 编写代码。我正在使用带有 keras 库的 python 来编写代码。为此,在定义了 Critic 和 Generator 函数之后,我编写了如下的 train 函数:

def train(g_model, d_model, gan_model, latent_dim, data, n_epochs=20, n_batch=500, n_critic = 5):    
    num_batch = int(data.shape[0] / n_batch)
    # calculate the number of training iterations
    n_steps = (num_batch + 1) * n_epochs
    # determine half the size of one batch, for updating the  discriminator
    half_batch = int(n_batch / 2)
    d_history = []
    g_history = []
    # manually enumerate epochs
    for epoch in range(n_steps):
        d_loss_lst = [0] * n_critic
        i = 0
        for i in range(0, n_critic):
            # prepare real samples
            x_real, y_real = generate_real_samples(half_batch, data)
            # prepare fake examples
            x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            # update discriminator
            d_loss_real, d_real_acc = d_model.train_on_batch(x_real, y_real)
            d_loss_fake, d_fake_acc = d_model.train_on_batch(x_fake, y_fake)
            d_loss_lst[i] = d_loss_real - d_loss_fake
        d_loss = sum(d_loss_lst)/ n_critic
        
        #Calculate Gradient Penalty        
        epsilon = tf.random.normal([half_batch, 1], 0.0, 1.0, dtype=tf.dtypes.float64)
        x_hat = epsilon * x_real + (1 - epsilon) * x_fake
        with tf.GradientTape() as t:
            t.watch(x_hat)
            d_hat = d_model(x_hat, training = True)
        gradients = t.gradient(d_hat, x_hat)
        ddx = tf.sqrt(tf.reduce_sum(gradients ** 2))
        gp = tf.reduce_mean((ddx - 1.0) ** 2)
        # Add the gradient penalty to the original discriminator loss
        d_loss_with_gp = d_loss + gp * gradient_penalty_weight
        print(pd.DataFrame(gradients))
        d_gradient = t.gradient(d_loss_with_gp, d_model.trainable_variables)
        d_optimizer.apply_gradients(zip(d_gradient, d_model.trainable_variables)) 
        
        # prepare points in latent space as input for the generator
        x_gan = generate_latent_points(latent_dim, n_batch)
        # create labels for the fake samples
        y_gan = np.ones((n_batch, 1))
        # update the generator via the discriminator's error
        g_loss = gan_model.train_on_batch(x_gan, y_gan)
        print('>%d, d=%.3f g=%.3f' % (epoch+1, d_loss,  g_loss))
        d_history.append(d_loss)
        g_history.append(g_loss)
    plot_history(d_history, g_history)
    g_model.save('trained_generated_model.h5')

但是,运行此代码时出现错误。有人可以告诉我我在做什么错误吗?

4

0 回答 0