1

我对如何在 jax 中计算高阶多元导数感到困惑。

例如,你如何计算 d^2f / dx dy

def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))

其中 x, y 在 R^n, n >= 1 中?

我一直在尝试jax.jvpand jax.partial,但我没有任何成功。

4

1 回答 1

1

由于xandy是向量值并且是标量,我相信您可以通过将and函数与适当的 argnumsf(x, y)组合来计算您所追求的:jax.jacfwdjax.jacrev

import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))

d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
  
x = jnp.arange(4.0)
y = jnp.ones(4)

print(d2f_dxdy(x, y))

# DeviceArray([[0.96017027, 0.        , 0.        , 0.        ],
#              [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
#              [0.558831  , 0.558831  , 1.5190012 , 0.558831  ],
#              [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
#             dtype=float32)
于 2020-12-13T07:43:49.593 回答