1

我的任务是使用 jax 为这个 函数找到 a 和 b 的导数

现在,我来这里的原因是因为我对 Python 的了解不够多,而对于所讨论的课程,我们也没有被认为是 Python。

任务是:

return a tuple (dfa, dfb) such that dfa is the partial derivatives of f by a,
           and dfb is the partial derivative of f by b

现在,我能够以正常方式做到这一点:

def function(a, b):
   dfa = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
   dfb = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
   return (dfa, dfb)

但我不熟悉算法微分,使用我们给出的例子,我试过这个:

def foo():

   x = (2/b)*sym.cos(a)
   y = sym.exp(-sym.Pow(a/b,2))
   return (x*y)

def f_partial_derviatives_algo():
   return jax.grad(foo)

但我收到此错误:

无法解压不可迭代的函数对象

如果有人可以帮助了解我如何做这样的事情,将不胜感激

4

1 回答 1

0

JAX 和 sympy 不兼容。您应该使用其中一个,而不是尝试将两者结合起来。

如果您想使用 JAX 在某个值处计算此函数的偏导数,您可以编写如下代码:

import jax.numpy as jnp
from jax import grad

def f(a, b):
  return (2 / b) * jnp.cos(a) * jnp.exp(- a ** 2 / b ** 2)

df_da = grad(f, argnums=0)
df_db = grad(f, argnums=1)

print(df_da(1.0, 1.0), df_db(1.0, 1.0))
# -1.4141841 0.3975322
于 2021-06-10T14:53:22.780 回答