我正在尝试通过将梯度应用于其优化器来训练鉴别器网络。但是,当我使用 tf.GradientTape 查找损失 wrt 训练变量的梯度时,会返回 None。这是训练循环:
def train_step():
#Generate noisy seeds
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as disc_tape:
pattern = generator(noise)
pattern = tf.reshape(tensor=pattern, shape=(28,28,1))
dataset = get_data_set(pattern)
disc_loss = tf.Variable(shape=(1,2), initial_value=[[0,0]], dtype=tf.float32)
disc_tape.watch(disc_loss)
for batch in dataset:
disc_loss.assign_add(discriminator(batch, training=True))
disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
代码说明
生成器网络从噪声中生成“模式”。然后,我通过对张量应用各种卷积从该模式生成数据集。返回的数据集是批处理的,因此我遍历数据集并通过将此批次的损失添加到总损失中来跟踪鉴别器的损失。
我所知道的
当两个变量之间没有图形连接时,tf.GradientTape 返回 None。但是损失和可训练变量之间没有图形连接吗?我相信我的错误与我如何跟踪 disc_loss tf.Variable 中的损失有关
我的问题
如何在迭代批处理数据集时跟踪损失,以便以后可以使用它来计算梯度?