我正在尝试基于 Wassertein-GP 方法为表格 GAN 编写代码。我正在使用带有 keras 库的 python 来编写代码。为此,在定义了 Critic 和 Generator 函数之后,我编写了如下的 train 函数:
def train(g_model, d_model, gan_model, latent_dim, data, n_epochs=20, n_batch=500, n_critic = 5):
num_batch = int(data.shape[0] / n_batch)
# calculate the number of training iterations
n_steps = (num_batch + 1) * n_epochs
# determine half the size of one batch, for updating the discriminator
half_batch = int(n_batch / 2)
d_history = []
g_history = []
# manually enumerate epochs
for epoch in range(n_steps):
d_loss_lst = [0] * n_critic
i = 0
for i in range(0, n_critic):
# prepare real samples
x_real, y_real = generate_real_samples(half_batch, data)
# prepare fake examples
x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
# update discriminator
d_loss_real, d_real_acc = d_model.train_on_batch(x_real, y_real)
d_loss_fake, d_fake_acc = d_model.train_on_batch(x_fake, y_fake)
d_loss_lst[i] = d_loss_real - d_loss_fake
d_loss = sum(d_loss_lst)/ n_critic
#Calculate Gradient Penalty
epsilon = tf.random.normal([half_batch, 1], 0.0, 1.0, dtype=tf.dtypes.float64)
x_hat = epsilon * x_real + (1 - epsilon) * x_fake
with tf.GradientTape() as t:
t.watch(x_hat)
d_hat = d_model(x_hat, training = True)
gradients = t.gradient(d_hat, x_hat)
ddx = tf.sqrt(tf.reduce_sum(gradients ** 2))
gp = tf.reduce_mean((ddx - 1.0) ** 2)
# Add the gradient penalty to the original discriminator loss
d_loss_with_gp = d_loss + gp * gradient_penalty_weight
print(pd.DataFrame(gradients))
d_gradient = t.gradient(d_loss_with_gp, d_model.trainable_variables)
d_optimizer.apply_gradients(zip(d_gradient, d_model.trainable_variables))
# prepare points in latent space as input for the generator
x_gan = generate_latent_points(latent_dim, n_batch)
# create labels for the fake samples
y_gan = np.ones((n_batch, 1))
# update the generator via the discriminator's error
g_loss = gan_model.train_on_batch(x_gan, y_gan)
print('>%d, d=%.3f g=%.3f' % (epoch+1, d_loss, g_loss))
d_history.append(d_loss)
g_history.append(g_loss)
plot_history(d_history, g_history)
g_model.save('trained_generated_model.h5')
但是,运行此代码时出现错误。有人可以告诉我我在做什么错误吗?