我正在尝试在 TPU 上训练 GAN,因此我一直在使用 TPUEstimator 类和随附的模型函数来尝试实现 WGAN 训练循环。我正在尝试tf.cond
将 TPUEstimatorSpec 的两个训练操作合并为:
opt = tf.cond(
tf.equal(tf.mod(tf.train.get_or_create_global_step(),
CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1),
lambda: gen_opt,
lambda: critic_opt
)
gen_opt
并且critic_opt
是我正在使用的优化器的最小化功能,也设置为更新全局步骤。CRITIC_UPDATES_PER_GEN_UPDATE
是一个 Python 常量,它是 WGAN 训练的一部分。我尝试使用 找到 GAN 模型tf.cond
,但所有模型都使用tf.group
,我不能使用它,因为您需要比生成器优化更多次批评者。但是,每次运行 100 个批次,全局步长根据检查点编号增加 200。我的模型是否仍在正确训练,或者tf.cond
不应该以这种方式用于训练 GAN?