我正在练习使用 TensorFlow 的custom_gradient
装饰器,并尝试定义一个简单的 ReLU。人们会认为这就像将梯度定义为何1
时x > 0
和0
否则一样简单。但是,以下代码不会产生与 ReLU 相同的梯度:
@tf.custom_gradient
def relu(x):
def grad(dy):
return tf.cond(tf.reshape(x, []) > 0,
lambda: tf.cast(tf.reshape(1, dy.shape), tf.float32),
lambda: tf.cast(tf.reshape(0, dy.shape), tf.float32))
return tf.nn.relu(x), grad
有人可以向我解释为什么这个 ReLU 梯度的标准定义不会产生与以下相同的性能:
@tf.custom_gradient
def relu(x):
def grad(dy):
return dy
return tf.nn.relu(x), grad