1
N= [-7.12843079e+02, -1.39668296e+02, -6.01626070e+01, -3.51688015e+01]
jax.scipy.special.digamma(N)
TypeError: digamma does not accept dtype complex64. Accepted dtypes are subtypes of floating.

我正在尝试使用 jax.scipy.special.digamma 以复数计算 digamma,但是,即使此包的文档说它可能很复杂,它仍然给我这个错误,这是文档所说的:

参数: z (array_like) – 实数或复数参数。

知道如何解决这个问题吗?或者是否有其他方法,例如其他库或其他包,允许我使用复数来计算 digamma 函数!?

4

1 回答 1

1

我有一个类似的问题。这就是我解决它的方法。漫长的道路。我知道有人会给出一个简洁的方法。

从定义psi 函数

我从这里使用了 Gamma 函数。确保输出在 JAX 中,否则您将无法使用 GRAD;

import jax.numpy as jnp
from jax import grad

def gamma_func_numeric(z):
    g = 7
    z -= 1
    x = lanczos_coef[0]

    for i in range(1, g+2):
        x +=   lanczos_coef[i]/(z+i)    
        t = z + g  + 0.5
   return jnp.sqrt(2*jnp.pi)*jnp.power(t,(z+0.5))*jnp.exp(-t)*x

eta is a complex number

def psi_numeric(eta):
    gamma_prime = grad(gamma_func_numeric, holomorphic=True)(eta)
    gamma = gamma_func_numeric(eta)
    return gamma_prime/gamma

让我们比较一下我们的结果:

scipy.special.psi(1.+2j)=(0.7145915153739777+1.3208072826422304j)

psi_numeric(1.+2j) = DeviceArray(0.7145916+1.3208078j, dtype=complex64)
于 2021-07-14T09:40:28.967 回答