I have written code for a generative adversarial network (GAN) that will run for 4000 epochs. However, after about 2000 epochs, model compilation time and memory usage become very inefficient and the code runs extremely slowly. I want to make my code memory-efficient.

Based on the following two posts, I believe the answer is to use clear_session at the end of every epoch:

https://github.com/keras-team/keras/issues/2828

https://github.com/keras-team/keras/issues/6457

But if I use clear_session at the end of every epoch, I need to save the discriminator's and generator's weights to disk before doing so. This strategy only works for the first epoch; after that I keep getting the error ValueError: Tensor("training_1/Adam/Const:0", shape=(), dtype=float32) must be from the same graph as Tensor("sub:0", shape=(), dtype=float32)., caused by tearing down and restarting the established TensorFlow graph. I also get the error Cannot interpret feed_dict key as Tensor: Tensor Tensor("conv1d_1_input:0", shape=(?, 750, 1), dtype=float32) is not an element of this graph.
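The first error can be reproduced outside of Keras entirely: TensorFlow refuses to build an op whose input tensors come from two different graphs. A minimal sketch of that failure mode (plain TensorFlow, independent of the GAN code below):

```python
import tensorflow as tf

# Build one tensor in each of two separate graphs.
g1 = tf.Graph()
with g1.as_default():
    a = tf.constant(1.0)

g2 = tf.Graph()
with g2.as_default():
    b = tf.constant(2.0)
    try:
        # The add op would be created in g2, but its input `a` lives in g1.
        c = a + b
        err = None
    except ValueError as e:
        # Same family of error as in the question:
        # "... must be from the same graph as ..."
        err = str(e)
```

This is what happens when something compiled against the old graph (a reloaded model, or an optimizer's slot variables such as training_1/Adam/Const:0) is combined with tensors created after clear_session has started a fresh graph.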

import numpy as np
from keras import backend as K
from keras.models import Sequential
# NB: load_model, save_model, add_layers and save_imgs here are my own
# project helpers (their signatures differ from keras.models.load_model);
# train_X, optimizer and first_half_length are defined earlier in the script.

discriminator=load_model('discriminator')
discriminator.trainable = False
gen_loss=[]
dis_loss=[]
epochs = 4000
batch_size = 100
save_interval = 100
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch in range(epochs):
    idx = np.random.randint(0, train_X.shape[0], batch_size)
    imgs = train_X[idx]
    # Sample noise and generate a batch of new images
    noise = np.random.normal(0, 1, (batch_size, int(train_X.shape[1]/4)))
    noise = noise.reshape(noise.shape[0], noise.shape[1], 1)
    generator = load_model('9_heterochromatin', 'generator', '1000')
    gen_imgs = generator.predict(noise)
    combined = add_layers(generator, discriminator, len(discriminator.layers))
    combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    # Train the discriminator (real classified as ones and generated as zeros)
    d_loss_real = discriminator.train_on_batch(imgs, valid)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    # Train the generator (wants discriminator to mistake images as real)
    g_loss = combined.train_on_batch(noise, valid)
    generator = add_layers(Sequential(), combined, first_half_length)
    save_model(generator, '9_heterochromatin', 'generator', '1000')
    gen_loss.append(g_loss)
    dis_loss.append(d_loss[0])
    # Plot the progress
    print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
    # If at save interval => save generated image samples
    if epoch % save_interval == 0:
        save_imgs(epoch, gen_loss, dis_loss)
    K.clear_session()

I am trying to build a memory-efficient GAN that saves and reloads the learned weights on each subsequent epoch while using clear_session to prevent the memory leak. Does anyone know how to achieve this without conflicting TensorFlow graphs?
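For context, the pattern I have been experimenting with looks roughly like the sketch below (my own assumptions, not a confirmed fix: `build_generator` stands in for whatever function defines the real architecture, and the temp-file path is illustrative). The idea is to persist weights only, then rebuild the architecture from scratch inside the fresh graph, reload the weights, and recompile, so that every tensor and optimizer variable belongs to the new graph:

```python
import os
import tempfile

import numpy as np
from keras import backend as K, Input
from keras.models import Sequential
from keras.layers import Dense

def build_generator():
    # Stand-in architecture (assumption); replace with the real generator.
    return Sequential([Input(shape=(4,)), Dense(8, activation='relu'), Dense(4)])

weights_path = os.path.join(tempfile.gettempdir(), 'generator.weights.h5')

generator = build_generator()
generator.compile(loss='mse', optimizer='adam')
weights_before = generator.get_weights()

# End of epoch: persist weights only (no graph state), then reset the backend.
generator.save_weights(weights_path)
K.clear_session()  # the old graph and every tensor in it are gone

# Start of next epoch: rebuild inside the fresh graph, reload, recompile.
generator = build_generator()            # new tensors, all in the new graph
generator.load_weights(weights_path)     # weights are graph-independent arrays
generator.compile(loss='mse', optimizer='adam')  # optimizer state is rebuilt too
weights_after = generator.get_weights()
```

The key point is that save_weights/load_weights only move numpy arrays, so nothing references the discarded graph, whereas saving and reloading a whole compiled model drags old-graph optimizer tensors along.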
