2

我是自动微分编程的新手,所以这可能是一个幼稚的问题。以下是我要解决的问题的简化版本。

我有两个输入数组 - 一个Asize的向量和一个shapeN的矩阵,以及一个size的参数向量。我定义了一个新数组来获得一个新的 size 向量。然后,我获取落在 的上下四分位数中的元素的索引,并使用它们创建一个新数组和。显然这两个确实依赖于,但是有可能区分wrt吗?B(N, M)thetaMC(theta) = B * thetaNCA_low(theta) = A[lower quartile indices of C]A_high(theta) = A[upper quartile indices of C]thetaA_lowA_hightheta

到目前为止,我的尝试似乎表明没有——我使用了 autograd、JAX 和 tensorflow 的 python 库,但它们都返回零梯度。(到目前为止,我尝试过的方法包括使用 argsort 或使用 提取相关子数组tf.top_k。)

我正在寻求帮助的是证明未定义导数(或无法分析计算)的证据,或者如果确实存在,则提供有关如何估计它的建议。我的最终目标是最小化某些功能f(A_low, A_high)wrt theta

4

1 回答 1

1

这是我根据您的描述编写的 JAX 计算:

import numpy as np
import jax.numpy as jnp
import jax

N = 10
M = 20

rng = np.random.default_rng(0)
A = jnp.array(rng.random((N,)))
B = jnp.array(rng.random((N, M)))
theta = jnp.array(rng.random(M))

def f(A, B, theta, k=3):
  C = B @ theta
  _, i_upper = lax.top_k(C, k)
  _, i_lower = lax.top_k(-C, k)
  return A[i_lower], A[i_upper]

x, y = f(A, B, theta)
dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)

导数都为零,我相信这是正确的,因为输出值的变化不取决于 的值的变化theta

但是,你可能会问,这怎么可能?毕竟,theta进入计算,如果你为 输入不同的值theta,你会得到不同的输出。梯度怎么可能为零?

但是,您必须记住的是,差异化并不能衡量输入是否影响输出。它测量给定输入的无穷小变化时的输出变化

让我们以一个稍微简单的函数为例:

import jax
import jax.numpy as jnp

A = jnp.array([1.0, 2.0, 3.0])
theta = jnp.array([5.0, 1.0, 3.0])

def f(A, theta):
  return A[jnp.argmax(theta)]

x = f(A, theta)
dx_dtheta = jax.grad(f, argnums=1)(A, theta)

f由于与theta上述相同的原因,这里的微分结果全为零。为什么?如果您对 进行微小的更改theta,通常不会影响 的排序顺序theta。因此,您选择的条目A不会因 theta 的微小变化而改变,因此相对于 theta 的导数为零。

现在,您可能会争辩说,在某些情况下情况并非如此:例如,如果 theta 中的两个值非常接近,那么即使是无限小的扰动一个值也肯定会改变它们各自的等级。这是真的,但是这个过程产生的梯度是不确定的(输出的变化相对于输入的变化并不平滑)。好消息是这种不连续性是一方面的:如果你在另一个方向上扰动,排名没有变化,梯度是明确定义的。为了避免未定义的梯度,大多数 autodiff 系统将隐含地使用这种更安全的导数定义来进行基于秩的计算。

结果是当你对输入进行无限小的扰动时,输出的值不会改变,这是梯度为零的另一种说法。这并不是 autodiff 的失败——它是基于 autodiff 的微分定义的正确梯度。此外,如果您尝试在这些不连续处更改为导数的不同定义,您可能希望得到的最好结果将是未定义的输出,因此导致零的定义可以说更有用和更正确。

于 2021-12-03T16:30:13.977 回答