I'm trying to generate samples from a multivariate normal distribution with JAX:
import jax
import jax.numpy as jnp
import numpy as np
key = jax.random.PRNGKey(0)
cov = np.array([[1.2, 0.4], [0.4, 1.0]])
mean = np.array([3,-1])
x1,x2 = jax.random.multivariate_normal(key, mean, cov, 5000).T
However, when I run the code, I get the following error:
TypeError Traceback (most recent call last)
<ipython-input-25-1397bf923fa4> in <module>()
2 cov = np.array([[1.2, 0.4], [0.4, 1.0]])
3 mean = np.array([3,-1])
----> 4 x1,x2 = jax.random.multivariate_normal(key, mean, cov, 5000).T
1 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in canonicalize_shape(shape)
1159 "got {}.")
1160 if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
-> 1161 and not isinstance(get_aval(x), ConcreteArray) for x in shape):
1162 msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
1163 "smaller subfunctions.")
TypeError: 'int' object is not iterable
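The last frame of the traceback is `canonicalize_shape`, which iterates over its argument (`for x in shape`). The fourth positional parameter of `jax.random.multivariate_normal` is `shape`, the batch shape given as a sequence of ints, not NumPy's `size`, so the bare int `5000` ends up in that loop and Python raises `'int' object is not iterable`. A minimal sketch of the corrected call, assuming a JAX version whose `multivariate_normal` takes the `shape=` keyword, is to pass a tuple:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
cov = jnp.array([[1.2, 0.4], [0.4, 1.0]])
mean = jnp.array([3.0, -1.0])

# `shape` is the batch shape and must be a sequence such as a tuple.
# Each draw has the event shape mean.shape[-1:], so the result is (5000, 2).
samples = jax.random.multivariate_normal(key, mean, cov, shape=(5000,))
x1, x2 = samples.T  # transpose to (2, 5000), then unpack the two coordinates

print(samples.shape)              # (5000, 2)
print(jnp.mean(samples, axis=0))  # should be close to [3, -1]
print(jnp.cov(samples.T))         # should be close to cov
```

The trailing comma in `(5000,)` matters: `(5000)` is just the int `5000` again and fails the same way.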
I'm not sure what the problem is, since the same syntax works with the equivalent NumPy function (`np.random.multivariate_normal`).
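NumPy's `np.random.multivariate_normal(mean, cov, size)` and `np.random.Generator.multivariate_normal` accept an int `size`, which is why the same call works there, while `jax.random` expects a sequence for `shape`. If you want to keep NumPy-style calls, one option is a thin wrapper that converts an int to a 1-tuple before calling JAX. `sample_mvn` below is a hypothetical helper written for illustration, not part of JAX or NumPy:

```python
import numpy as np
import jax


def sample_mvn(key, mean, cov, size=None):
    """Hypothetical helper: accept a NumPy-style `size` (None, int, or
    sequence) and forward it to jax.random.multivariate_normal as a tuple."""
    if size is None:
        shape = None                 # let JAX derive the batch shape from mean/cov
    elif isinstance(size, (int, np.integer)):
        shape = (int(size),)         # 5000 -> (5000,)
    else:
        shape = tuple(size)          # e.g. [10, 3] -> (10, 3)
    return jax.random.multivariate_normal(key, mean, cov, shape=shape)


mean = np.array([3.0, -1.0])
cov = np.array([[1.2, 0.4], [0.4, 1.0]])

# NumPy accepts an int `size` directly ...
np_samples = np.random.default_rng(0).multivariate_normal(mean, cov, size=5000)
# ... and the wrapper lets the JAX call take the same argument.
jax_samples = sample_mvn(jax.random.PRNGKey(0), mean, cov, 5000)

print(np_samples.shape, jax_samples.shape)  # (5000, 2) (5000, 2)
```

The two arrays have the same shape but not the same values: NumPy and JAX use different random number generators, so identical seeds do not produce identical samples.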