当我增加训练图像的数量时,程序报出了如下面截图所示的错误:(截图:错误描述)训练代码如下:
def train(get_batches, data_shape, LR_G=2e-4, LR_D=0.0005):
    """Train the GAN for ``EPOCHS`` epochs and plot the per-epoch losses.

    Builds the graph via ``model_inputs`` / ``model_loss`` /
    ``model_optimizers``, then alternates one discriminator step and one
    generator step per batch. After every epoch ``summarize_epoch`` is called
    and the mean of the most recent losses is recorded. After the final epoch
    ``generate`` is called, and the loss curves are written to
    ``train_losses.png`` and shown.

    Args:
        get_batches: Iterable of image batches; it is iterated once per epoch.
            NOTE(review): if the caller passes a one-shot generator it will be
            exhausted after the first epoch and later epochs will see no
            batches -- confirm the caller passes something re-iterable.
        data_shape: Shape of the training data. ``data_shape[0]`` is used as
            the sample count, ``data_shape[1:]`` is passed to ``model_inputs``
            and ``data_shape[3]`` to ``model_loss``.
        LR_G: Value fed to the generator learning-rate placeholder.
        LR_D: Value fed to the discriminator learning-rate placeholder.
    """
    input_images, input_z, lr_G, lr_D = model_inputs(data_shape[1:], NOISE_SIZE)
    d_loss, g_loss = model_loss(input_images, input_z, data_shape[3])
    d_opt, g_opt = model_optimizers(d_loss, g_loss)

    train_d_losses = []
    train_g_losses = []

    # Number of most-recent loss entries averaged per epoch. Clamp to 1 so a
    # dataset smaller than BATCH_SIZE does not produce the slice [-0:], which
    # would silently average the whole loss history instead.
    batches_per_epoch = max(1, int(data_shape[0] // BATCH_SIZE))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        d_losses = []
        g_losses = []

        # Epochs are numbered from 1 so `epoch == EPOCHS` marks the last one.
        for epoch in tqdm(range(1, EPOCHS + 1)):
            start_time = time.time()

            for batch_images in get_batches:
                # Size the noise batch from the actual image batch so a short
                # final batch does not mismatch the real-image batch size.
                batch_z = np.random.uniform(
                    -1, 1, size=(len(batch_images), NOISE_SIZE)
                )
                sess.run(
                    d_opt,
                    feed_dict={
                        input_images: batch_images,
                        input_z: batch_z,
                        lr_D: LR_D,
                    },
                )
                sess.run(
                    g_opt,
                    feed_dict={
                        input_images: batch_images,
                        input_z: batch_z,
                        lr_G: LR_G,
                    },
                )
                # Losses are evaluated after both optimizer steps.
                d_losses.append(
                    d_loss.eval({input_z: batch_z, input_images: batch_images})
                )
                g_losses.append(g_loss.eval({input_z: batch_z}))

            summarize_epoch(
                epoch,
                time.time() - start_time,
                sess,
                d_losses,
                g_losses,
                input_z,
                data_shape,
            )
            train_d_losses.append(np.mean(d_losses[-batches_per_epoch:]))
            train_g_losses.append(np.mean(g_losses[-batches_per_epoch:]))

            if epoch == EPOCHS:
                generate(sess, input_z, out_channel_dim=3)

    # Start a fresh figure so the loss curves are not drawn onto a figure left
    # open by earlier plotting code.
    plt.subplots()
    plt.plot(train_d_losses, label='Discriminator', alpha=0.5)
    plt.plot(train_g_losses, label='Generator', alpha=0.5)
    plt.title("Training Losses")
    plt.legend()
    plt.savefig('train_losses.png')
    plt.show()
    plt.close()
我认为导致错误的代码显示在下面的屏幕截图中:(截图:代码)
请帮我解决这个问题。