9

我理解自动微分的概念,但找不到任何解释 tensorflow 如何计算不可微分函数的误差梯度,例如tf.where在我的损失函数或tf.cond图表中。它工作得很好,但我想了解 tensorflow 如何通过这些节点反向传播错误,因为没有公式可以计算它们的梯度。

4

1 回答 1

7

在 的情况下tf.where,您有一个具有三个输入的函数,条件C、trueT值和 false 值F,以及一个输出Out。梯度接收一个值并且必须返回三个值。目前,没有为条件计算梯度(这几乎没有意义),所以你只需要为T和做梯度F。假设输入和输出是向量,想象C[0]一下True。然后Out[0]来自T[0],它的梯度应该传播回来。另一方面,F[0]将被丢弃,因此应将其梯度设为零。如果Out[1]False,则 的 梯度F[1]应该传播但不会传播T[1]。所以,简而言之,对于T你应该传播给定的梯度 where CisTrue并在它所在的地方使其为零False,而对于F. 如果你看一下( operation)的梯度的实现tf.whereSelect,它确实是这样的:

@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  zeros = array_ops.zeros_like(x)
  return (None, array_ops.where(c, grad, zeros), array_ops.where(
      c, zeros, grad))

请注意,输入值本身不用于计算,这将通过产生这些输入的操作的梯度来完成。对于tf.cond代码稍微复杂一些,因为同一个操作 ( Merge) 用在不同的上下文中,里面tf.cond也用到了Switch操作。然而想法是一样的。本质上,Switch操作用于每个输入,因此被激活的输入(如果条件是True,则第一个,否则第二个)获得接收到的梯度,另一个输入获得“关闭”梯度(如None),并且不传播更进一步。

于 2018-11-08T16:51:33.193 回答