在CycleGAN 论文中提到了鉴别器的历史池。因此,我们保留了来自生成器的最后 50 个样本,并将它们提供给鉴别器。如果没有历史,这很简单,我们可以利用tf.data.Dataset
迭代器将数据插入网络。但是有了历史池,我不知道如何使用tf.data.Dataset
api。训练循环内的代码看起来像
fx, fy = sess.run(model_ops['fakes'], feed_dict={
self.cur_x: cur_x,
self.cur_y: cur_y,
})
cur_x, cur_y = sess.run([self.X_feed.feed(), self.Y_feed.feed()])
feeder_dict = {
self.cur_x: cur_x,
self.cur_y: cur_y,
self.prev_fake_x: x_pool.query(fx, step),
self.prev_fake_y: y_pool.query(fy, step),
}
# self.cur_x, self.cur_y, self.prev_fake_x, self.prev_fake_y are just placeholders
# x_pool and y_pool are simple wrappers for random sampling from the history pool and saving new images to the pool
for _ in range(dis_train):
sess.run(model_ops['train']['dis'], feed_dict=feeder_dict)
for _ in range(gen_train):
sess.run(model_ops['train']['gen'], feed_dict=feeder_dict)
令我困扰的是代码效率低下,例如在训练期间不可能像tf.data
API 的预取那样预加载下一批,但我看不到任何使用tf.data
API 的方法。它是否提供某种历史池,我可以将其用于预取并通常优化数据加载模型?此外,当我在鉴别器训练 op 和生成器训练 op 之间有一定比例时,也会出现类似的问题。例如,如果我想每 1 步鉴别器运行 2 步生成器训练操作,可以使用相同的数据来完成吗?因为使用tf.data
API,每次调用 sess.run 时都会从迭代器中提取新样本。
有什么方法可以正确有效地实施吗?