这是我根据您的描述编写的 JAX 计算:
import numpy as np
import jax.numpy as jnp
import jax
N = 10
M = 20
rng = np.random.default_rng(0)
A = jnp.array(rng.random((N,)))
B = jnp.array(rng.random((N, M)))
theta = jnp.array(rng.random(M))
def f(A, B, theta, k=3):
C = B @ theta
_, i_upper = lax.top_k(C, k)
_, i_lower = lax.top_k(-C, k)
return A[i_lower], A[i_upper]
x, y = f(A, B, theta)
dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)
导数都为零,我相信这是正确的,因为输出值的变化不取决于 的值的变化theta
。
但是,你可能会问,这怎么可能?毕竟,theta
进入计算,如果你为 输入不同的值theta
,你会得到不同的输出。梯度怎么可能为零?
但是,您必须记住的是,差异化并不能衡量输入是否影响输出。它测量给定输入的无穷小变化时的输出变化。
让我们以一个稍微简单的函数为例:
import jax
import jax.numpy as jnp
A = jnp.array([1.0, 2.0, 3.0])
theta = jnp.array([5.0, 1.0, 3.0])
def f(A, theta):
return A[jnp.argmax(theta)]
x = f(A, theta)
dx_dtheta = jax.grad(f, argnums=1)(A, theta)
f
由于与theta
上述相同的原因,这里的微分结果全为零。为什么?如果您对 进行微小的更改theta
,通常不会影响 的排序顺序theta
。因此,您选择的条目A
不会因 theta 的微小变化而改变,因此相对于 theta 的导数为零。
现在,您可能会争辩说,在某些情况下情况并非如此:例如,如果 theta 中的两个值非常接近,那么即使是无限小的扰动一个值也肯定会改变它们各自的等级。这是真的,但是这个过程产生的梯度是不确定的(输出的变化相对于输入的变化并不平滑)。好消息是这种不连续性是一方面的:如果你在另一个方向上扰动,排名没有变化,梯度是明确定义的。为了避免未定义的梯度,大多数 autodiff 系统将隐含地使用这种更安全的导数定义来进行基于秩的计算。
结果是当你对输入进行无限小的扰动时,输出的值不会改变,这是梯度为零的另一种说法。这并不是 autodiff 的失败——它是基于 autodiff 的微分定义的正确梯度。此外,如果您尝试在这些不连续处更改为导数的不同定义,您可能希望得到的最好结果将是未定义的输出,因此导致零的定义可以说更有用和更正确。