2

考虑实现一个需要临时变量实例化的自定义损失函数。如果我们需要实现自定义梯度,TF 期望梯度函数有一个额外的输出,此时梯度的分量应该与损失函数的输入一样多。也就是说,如果我的理解是正确的。任何更正表示赞赏。

链接相关的 github 问题,其中包含一个最小工作示例 (MWE) 和其他调试信息:https ://github.com/tensorflow/tensorflow/issues/31945

从 github 帖子复制粘贴的 MWE 是:

import tensorflow as tf
# from custom_gradient import custom_gradient  # my corrected version
from tensorflow import custom_gradient


def layer(t, name):
    var = tf.Variable(1.0, dtype=tf.float32, use_resource=True, name=name)
    return t * var


@custom_gradient
def custom_gradient_layer(t):
    result = layer(t, name='outside')

    def grad(*grad_ys, variables=None):
        assert variables is not None
        print(variables)
        grads = tf.gradients(
            layer(t, name='inside'),
            [t, *variables],
            grad_ys=grad_ys,
        )
        grads = (grads[:1], grads[1:])
        return grads

    return result, grad

哪个会抛出ValueError: not enough values to unpack....

如果我的理解是正确的,通常对于伴随方法(反向模式 autodiff),前向传递构建表达式树,而对于反向传递,我们评估梯度,梯度函数是值乘以我们函数的偏导数' d 取关于的导数,它可能是一个复合函数。如果需要,我可以发布参考。

因此,对于一个输入变量,我们将对梯度进行一次评估。在这里,TF 期望 2,即使我们只有一个输入变量,因为临时变量在某些情况下是不可避免的。

我的 MWE 伪代码是这样的:

@tf.custom_gradient
def custom_loss(in):

    temp = tf.Variable(tf.zeros([2 * N - 1]), dtype = tf.float32)

    ## compute loss function
    ...

     def grad(df):
         grad = df * partial_derivative
         return grad

    return loss, grad

安德烈

4

1 回答 1

1

我有同样的问题。我发现添加 trainable=False 为我解决了这个问题。例如以下

import tensorflow as tf
@tf.custom_gradient
def custom_loss(x):

    temp = tf.Variable(1., dtype = tf.float32)

    loss = x*temp

     def grad(dL):
         grad = dL * temp
         return grad

    return loss, grad

给我错误“TypeError:如果将@custom_gradient 与使用变量的函数一起使用,则 grad_fn 必须接受关键字参数'variables'。”

但如果我这样做,我不会出错

import tensorflow as tf
@tf.custom_gradient
def custom_loss(x):

    temp = tf.Variable(1., dtype = tf.float32, trainable=False)

    loss = x*temp

     def grad(dL):
         grad = dL * temp
         return grad

    return loss, grad

希望这对您或其他人有所帮助。

于 2020-07-09T02:35:45.493 回答