1

我想使用 resnet 在数据集上实现“全批”梯度下降算法。但是,由于在全批次梯度中,每次迭代都需要计算所有训练点的梯度,所以当我使用 tf.GradinetTape 时,它​​会引发 OOM 错误。这是一个计算梯度的函数:

@tf.function
def compute_gradients(training_ds):  # this function computer gradients for other functions
    image, label  = training_ds
    with tf.GradientTape() as tape:
      predictions = model(image)
      loss = loss_fn(label, predictions)    
    grad = tape.gradient(loss, model.trainable_variables)
return grad

但是,由于网络非常大,我遇到了 OOM 错误。如何有效地计算全批次梯度下降的梯度?

4

0 回答 0