Jax 是否支持将导数作为向量值变量的索引?考虑这个例子(a
向量/数组在哪里):
def test_func(a):
return a[0]**a[1]
我可以将参数编号传递给grad(..)
,但我似乎无法像上面的示例那样传递向量值参数的索引。我尝试传递一个元组的元组,即
grad(test_func, argnums=((0,),))
但这不起作用。
没有内置变换可以对数组的某些元素进行渐变,但是您可以通过将数组拆分为单个元素的包装函数直接执行此操作;例如:
import jax
import jax.numpy as jnp
def test_func(a):
return a[0]**a[1]
a = jnp.array([1.0, 2.0])
fgrad = jax.grad(lambda *args: test_func(jnp.array(args)), argnums=0)
print(fgrad(*a))
# 2.0
如果您想对所有输入单独进行渐变(返回关于每个条目的渐变向量),您可以使用jax.jacobian
:
print(jax.jacobian(test_func)(a))
# [2. 0.]