我正在尝试从具有特定标准偏差和均值的高斯中采样,我知道以下函数是从均值为零且标准偏差等于 1 的高斯中采样:
import jax
from jax import random
key = random.PRNGKey(0)
mu = 20
std = 4
x1 = jax.random.normal(key, (1000,))
我可以通过这样做来调整平均值:x1 = x1 + mu
,但是如何调整标准偏差?
这
x1 = std * x1 + mu
会给你想要的
以这种方式创建您的示例:
x1 = mu + std * jax.random.normal(key, (1000,))
如果这样做,样本的直方图将遵循预期分布:
import jax
from jax import random
from jax.scipy.stats import norm
import matplotlib.pyplot as plt
key = random.PRNGKey(0)
mu = 20
std = 4
x1 = mu + std * jax.random.normal(key, (1000,))
plt.hist(x1, bins=50, density=True)
x = jnp.linspace(5, 35, 100)
y = norm.pdf(x, loc=mu, scale=std)
plt.plot(x, y)