I have written code for a generative adversarial network (GAN) that will run for 4000 epochs. However, after about 2000 epochs, model compile time and memory usage become very inefficient and the code runs very slowly. I want to make my code memory-efficient.
Based on the following two threads, I believe the answer is to call clear_session at the end of each epoch:
https://github.com/keras-team/keras/issues/2828
https://github.com/keras-team/keras/issues/6457
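As far as I understand those threads, the suggested pattern is roughly this (a minimal sketch, not my actual training code):

from keras import backend as K

for epoch in range(epochs):
    # ... build/train the models for one epoch ...
    K.clear_session()  # drop Keras's accumulated global graph state each epoch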
However, if I call clear_session at the end of each epoch, I need to save the discriminator's and generator's weights to disk beforehand. This strategy only works for the first epoch; after that I keep getting ValueError: Tensor("training_1/Adam/Const:0", shape=(), dtype=float32) must be from the same graph as Tensor("sub:0", shape=(), dtype=float32)., an error caused by tearing down and restarting the established TensorFlow graph. I also get the error Cannot interpret feed_dict key as Tensor: Tensor Tensor("conv1d_1_input:0", shape=(?, 750, 1), dtype=float32) is not an element of this graph.
import numpy as np
from keras import backend as K
from keras.models import Sequential

# load_model, save_model, add_layers and save_imgs are my own helpers defined
# elsewhere in the script; optimizer, first_half_length and train_X are also
# defined earlier.
discriminator = load_model('discriminator')
discriminator.trainable = False

gen_loss = []
dis_loss = []
epochs = 4000
batch_size = 100
save_interval = 100

# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(epochs):
    # Select a random batch of real samples
    idx = np.random.randint(0, train_X.shape[0], batch_size)
    imgs = train_X[idx]

    # Sample noise and generate a batch of new images
    noise = np.random.normal(0, 1, (batch_size, int(train_X.shape[1] / 4)))
    noise = noise.reshape(noise.shape[0], noise.shape[1], 1)
    generator = load_model('9_heterochromatin', 'generator', '1000')
    gen_imgs = generator.predict(noise)

    # Stack the generator and the (frozen) discriminator into the combined model
    combined = add_layers(generator, discriminator, len(discriminator.layers))
    combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    # Train the discriminator (real classified as ones and generated as zeros)
    d_loss_real = discriminator.train_on_batch(imgs, valid)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train the generator (wants discriminator to mistake images as real)
    g_loss = combined.train_on_batch(noise, valid)

    # Split the generator back out of the combined model and save it to disk
    generator = add_layers(Sequential(), combined, first_half_length)
    save_model(generator, '9_heterochromatin', 'generator', '1000')

    gen_loss.append(g_loss)
    dis_loss.append(d_loss[0])

    # Plot the progress
    print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

    # If at save interval => save generated image samples
    if epoch % save_interval == 0:
        save_imgs(epoch, gen_loss, dis_loss)

    K.clear_session()
I am trying to build a memory-efficient GAN that saves and reloads the learned weights at every subsequent epoch, while using clear_session to prevent the memory leak. Does anyone know how to achieve this without conflicting TensorFlow graphs?
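For reference, the per-epoch save/clear/rebuild pattern I am trying to reach looks roughly like the sketch below. build_generator and build_discriminator are stand-ins for my real model definitions; the idea is to persist only the weight arrays and then recreate the architecture inside the fresh graph, so that nothing references tensors from the old graph:

from keras import backend as K

for epoch in range(epochs):
    if epoch > 0:
        K.clear_session()                        # start this epoch with a fresh graph
        generator = build_generator()            # recreate the architecture in the new graph
        discriminator = build_discriminator()
        generator.load_weights('generator_weights.h5')          # restore learned weights only
        discriminator.load_weights('discriminator_weights.h5')
        discriminator.compile(loss='binary_crossentropy',
                              optimizer='adam', metrics=['accuracy'])
        # ... rebuild and compile the combined model here as well ...

    # ... one epoch of GAN training as in the loop above ...

    # save_weights stores just the weight arrays, not the graph or optimizer state
    generator.save_weights('generator_weights.h5')
    discriminator.save_weights('discriminator_weights.h5')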