0

当我增加训练图像的数量时,程序报出如下截图所示的错误:[错误截图]。训练代码如下:

def train(get_batches, data_shape, LR_G=2e-4, LR_D=0.0005):
    """Train the GAN for EPOCHS epochs, then plot the per-epoch mean losses.

    Args:
        get_batches: Iterable of image batches, iterated once per epoch. It
            must be re-iterable (e.g. a list of arrays); a one-shot generator
            is exhausted after the first epoch.
        data_shape: Shape of the full image dataset. ``data_shape[1:]`` is
            passed to ``model_inputs`` and ``data_shape[3]`` to ``model_loss``
            (presumably the channel count -- confirm against model_loss).
        LR_G: Learning rate fed to the generator optimizer.
        LR_D: Learning rate fed to the discriminator optimizer.

    Raises:
        ValueError: If an epoch yields no batches.

    Side effects:
        Calls ``summarize_epoch`` after every epoch and ``generate`` after the
        last one, saves the loss plot to 'train_losses.png', and shows it.
    """
    input_images, input_z, lr_G, lr_D = model_inputs(data_shape[1:], NOISE_SIZE)
    d_loss, g_loss = model_loss(input_images, input_z, data_shape[3])
    d_opt, g_opt = model_optimizers(d_loss, g_loss)

    # Per-epoch mean losses, plotted once training finishes.
    train_d_losses = []
    train_g_losses = []

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Per-iteration losses accumulated across all epochs (summarize_epoch
        # receives the full history).
        d_losses = []
        g_losses = []

        # Epochs are numbered from 1 so `epoch == EPOCHS` marks the last one.
        for epoch in tqdm(range(1, EPOCHS + 1)):
            start_time = time.time()
            epoch_start = len(g_losses)  # index of this epoch's first loss entry

            for batch_images in get_batches:
                # Size the noise to the actual batch. A fixed BATCH_SIZE would
                # pair a smaller final batch of real images with a full batch
                # of fake ones.
                batch_z = np.random.uniform(-1, 1, size=(len(batch_images), NOISE_SIZE))
                sess.run(d_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_D: LR_D})
                sess.run(g_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_G: LR_G})
                d_losses.append(d_loss.eval({input_z: batch_z, input_images: batch_images}))
                g_losses.append(g_loss.eval({input_z: batch_z}))

            if len(g_losses) == epoch_start:
                raise ValueError(
                    f"get_batches yielded no batches in epoch {epoch}; pass a "
                    "re-iterable collection (e.g. a list) with at least one batch"
                )

            summarize_epoch(epoch, time.time() - start_time, sess, d_losses, g_losses, input_z, data_shape)
            # Average exactly this epoch's batches, rather than assuming
            # data_shape[0] // BATCH_SIZE batches per epoch.
            train_d_losses.append(np.mean(d_losses[epoch_start:]))
            train_g_losses.append(np.mean(g_losses[epoch_start:]))

            if epoch == EPOCHS:
                generate(sess, input_z, out_channel_dim=3)

    fig, ax = plt.subplots()
    ax.plot(train_d_losses, label='Discriminator', alpha=0.5)
    ax.plot(train_g_losses, label='Generator', alpha=0.5)
    ax.set_title("Training Losses")
    ax.legend()
    fig.savefig('train_losses.png')
    plt.show()
    plt.close(fig)

我认为引发错误的是下面截图中的这些函数:[代码截图]

train

summarize_epoch

test

show_samples

请帮我解决这个问题

4

0 回答 0