我对如何在 jax 中计算高阶多元导数感到困惑。
例如,你如何计算 d^2f / dx dy
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
其中 x, y 在 R^n, n >= 1 中?
我一直在尝试jax.jvp
and jax.partial
,但我没有任何成功。
由于x
andy
是向量值并且是标量,我相信您可以通过将and函数与适当的 argnumsf(x, y)
组合来计算您所追求的:jax.jacfwd
jax.jacrev
import jax.numpy as jnp
from jax import jacfwd, jacrev
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
x = jnp.arange(4.0)
y = jnp.ones(4)
print(d2f_dxdy(x, y))
# DeviceArray([[0.96017027, 0. , 0. , 0. ],
# [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
# [0.558831 , 0.558831 , 1.5190012 , 0.558831 ],
# [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
# dtype=float32)