我的任务是使用 jax 为这个 函数找到 a 和 b 的导数
现在,我来这里的原因是因为我对 Python 的了解不够多,而对于所讨论的课程,我们也没有被认为是 Python。
任务是:
return a tuple (dfa, dfb) such that dfa is the partial derivatives of f by a,
and dfb is the partial derivative of f by b
现在,我能够以正常方式做到这一点:
def function(a, b):
dfa = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
dfb = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
return (dfa, dfb)
但我不熟悉算法微分,使用我们给出的例子,我试过这个:
def foo():
x = (2/b)*sym.cos(a)
y = sym.exp(-sym.Pow(a/b,2))
return (x*y)
def f_partial_derviatives_algo():
return jax.grad(foo)
但我收到此错误:
无法解压不可迭代的函数对象
如果有人可以帮助了解我如何做这样的事情,将不胜感激