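I couldn't find a built-in function for this in NumPy, but you can implement it yourself. Compute the Cholesky decomposition of the covariance matrix, Σ = LLᵀ, and then use the following fact: if X is a vector of iid standard normal variables, the transform LX + µ has covariance Σ and mean µ.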
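Here is a short derivation of that fact. It uses only linearity of expectation and E[X] = 0, Cov(X) = I:

$$
\begin{aligned}
\mathbb{E}[LX + \mu] &= L\,\mathbb{E}[X] + \mu = \mu,\\
\operatorname{Cov}(LX + \mu) &= L\,\operatorname{Cov}(X)\,L^\mathsf{T} = L\,I\,L^\mathsf{T} = LL^\mathsf{T} = \Sigma.
\end{aligned}
$$

Because LX + µ is an affine transform of a Gaussian vector, it is itself Gaussian. Its mean and covariance therefore determine its distribution completely.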
This can be implemented using e.g. `np.linalg.cholesky()` (note that this function supports batch mode!) and `np.random.standard_normal()`:
```python
# S: sample shape, B: batch shape (tuples of ints); D: dimension
# cov:    (*B, D, D)
# mean:   (*B, D)
# result: (*S, *B, D)
L = np.linalg.cholesky(cov)
X = np.random.standard_normal((*S, *B, D, 1))
Y = (L @ X).reshape(*S, *B, D) + mean
```
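The following runnable version of the snippet uses hypothetical sizes S = (200_000,), B = (3,), D = 2 and three hand-picked covariance matrices. It also checks empirically that each batch element ends up with the intended mean and covariance. This sketch swaps in a seeded `np.random.default_rng` generator in place of the global `np.random` functions, only for reproducibility:

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator (an assumption; the answer uses np.random.*)

S, B, D = (200_000,), (3,), 2   # hypothetical: 200k samples for each of 3 distributions in 2D

# Illustrative covariance matrices (all symmetric positive definite) and means
cov = np.array([[[1.0, 0.0], [0.0, 1.0]],
                [[2.0, 0.8], [0.8, 1.0]],
                [[0.5, -0.3], [-0.3, 0.4]]])
mean = np.array([[0.0, 0.0], [1.0, -1.0], [5.0, 2.0]])

L = np.linalg.cholesky(cov)                  # (*B, D, D), batched decomposition
X = rng.standard_normal((*S, *B, D, 1))      # (*S, *B, D, 1), iid N(0, 1)
Y = (L @ X).reshape(*S, *B, D) + mean        # (*S, *B, D)

# Empirical mean and covariance of each batch element, over the sample axis 0
emp_mean = Y.mean(axis=0)                    # (*B, D)
centered = Y - emp_mean
emp_cov = np.einsum('sbi,sbj->bij', centered, centered) / (S[0] - 1)

print(Y.shape)                        # (200000, 3, 2)
print(np.abs(emp_cov - cov).max())    # small, and shrinks as S grows
```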
Here it is, wrapped in a function for ease of use:
```python
import numpy as np


def sample_batch_mvn(
    mean: np.ndarray,
    cov: np.ndarray,
    size: "tuple | int" = (),
) -> np.ndarray:
    """
    Batch sample multivariate normal distribution.

    Arguments:

        mean: expected values of shape (…M, D)
        cov: covariance matrices of shape (…M, D, D)
        size: additional batch shape (…B)

    Returns: samples from the multivariate normal distributions
             shape: (…B, …M, D)

    It is not required that ``mean`` and ``cov`` have the same shape
    prefix, only that they are broadcastable against each other.
    """
    mean = np.asarray(mean)
    cov = np.asarray(cov)
    size = (size, ) if isinstance(size, int) else tuple(size)
    # output shape: extra sample axes, then the broadcast batch shape of mean/cov
    shape = size + np.broadcast_shapes(mean.shape, cov.shape[:-1])
    # iid standard normals, with a trailing axis so they act as column vectors
    X = np.random.standard_normal((*shape, 1))
    L = np.linalg.cholesky(cov)
    # Y = L X + mean, broadcasting L over the leading sample axes
    return (L @ X).reshape(shape) + mean
```
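The least obvious part of the function is the shape handling. The sketch below reproduces its `np.broadcast_shapes` step with hypothetical shapes. It shows how a batch of 3 means and a batch of 4 covariance matrices combine into a 3 × 4 grid of distributions. An identity matrix stands in for the Cholesky factor because only the shapes matter here:

```python
import numpy as np

# Hypothetical shapes: 3 means (with a singleton axis) against 4 covariance matrices
mean_shape = (3, 1, 2)   # (…M, D)
cov_shape = (4, 2, 2)    # (…M, D, D)
size = (10,)             # extra sample shape (…B)

# Same rule as in sample_batch_mvn: drop the last axis of cov, then broadcast
batch = np.broadcast_shapes(mean_shape, cov_shape[:-1])
shape = size + batch
print(batch)   # (3, 4, 2)
print(shape)   # (10, 3, 4, 2)

# Stand-in factor of shape (4, 2, 2); matmul broadcasts it over the leading axes of X
L = np.broadcast_to(np.eye(2), cov_shape)
X = np.zeros((*shape, 1))
print((L @ X).shape)   # (10, 3, 4, 2, 1)
```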
To test this function, we first need a good batch of covariance matrices. We'll generate a large batch of them to test sampling performance:
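```python
# Generate a batch of N covariance matrices of size D × D:
N = 5000
D = 2
L = np.zeros((N, D, D))
# fill the lower triangle (including the diagonal) with standard normal entries
L[(..., *np.tril_indices(D))] = \
    np.random.normal(size=(N, D * (D + 1) // 2))
cov = L @ np.swapaxes(L, -1, -2)   # cov = L Lᵀ
```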
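Before sampling, you may want to confirm that this construction yields valid covariance matrices. The sketch below is an addition and uses a seeded generator. It checks symmetry and positive definiteness with `np.linalg.eigvalsh`. It also shows that `np.linalg.cholesky` recovers the sampled factors only up to the signs of their columns, because it always returns the factor with a positive diagonal:

```python
import numpy as np

rng = np.random.default_rng(1)  # seeded generator (an assumption, for reproducibility)
N, D = 5000, 2

# Same construction as above: random lower-triangular factors, then cov = L Lᵀ
L = np.zeros((N, D, D))
L[(..., *np.tril_indices(D))] = rng.normal(size=(N, D * (D + 1) // 2))
cov = L @ np.swapaxes(L, -1, -2)

# Symmetric by construction
sym_ok = np.allclose(cov, np.swapaxes(cov, -1, -2))
# eigvalsh handles stacks of symmetric matrices; all eigenvalues > 0 means positive definite
min_eig = np.linalg.eigvalsh(cov).min()

# The Cholesky factor has a positive diagonal, so it equals L with each column
# multiplied by the sign of that column's diagonal entry
L_hat = np.linalg.cholesky(cov)
signs = np.sign(np.diagonal(L, axis1=-2, axis2=-1))   # (N, D)
same_up_to_signs = np.allclose(L_hat, L * signs[..., None, :])

print(sym_ok, min_eig > 0, same_up_to_signs)
```

With continuous random entries, a diagonal entry of exactly zero (which would make `cov` singular) has probability zero. Very small diagonal entries can still make individual matrices badly conditioned.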
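The method used here actually generates the covariance matrices by sampling their Cholesky factors L. Since we already have these factors, we would not need to compute the Cholesky decomposition inside the sampling function. However, to test the function's general applicability, we'll ignore them and pass only the covariance matrices `cov`: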
```python
mean = np.zeros(2)
samples = sample_batch_mvn(mean, cov, 1000)
print(samples.shape)  # (1000, 5000, 2)
```
Sampling these 5 million 2D vectors takes about 0.4 seconds on my PC.
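When the factors L are already at hand, as with the construction above, you can skip the Cholesky call and use the factors directly. The following sketch assumes a hypothetical helper `sample_batch_mvn_from_factor` that takes the factor instead of the covariance. It applies the same transform LX + µ, so it produces the same distribution. It does not produce the same numbers, because `np.linalg.cholesky` may return a factor whose column signs differ from the sampled L. The sizes are reduced here so that the empirical check stays cheap:

```python
import numpy as np


def sample_batch_mvn_from_factor(mean, L, size=(), rng=None):
    """Hypothetical variant of sample_batch_mvn that takes a (…M, D, D)
    factor L with L @ Lᵀ == cov instead of cov itself, so no Cholesky
    decomposition is computed."""
    rng = np.random.default_rng() if rng is None else rng
    mean = np.asarray(mean)
    L = np.asarray(L)
    size = (size,) if isinstance(size, int) else tuple(size)
    shape = size + np.broadcast_shapes(mean.shape, L.shape[:-1])
    X = rng.standard_normal((*shape, 1))
    return (L @ X).reshape(shape) + mean


rng = np.random.default_rng(2)
N, D = 4, 2  # small batch so the empirical check below is cheap
L = np.zeros((N, D, D))
L[(..., *np.tril_indices(D))] = rng.normal(size=(N, D * (D + 1) // 2))
cov = L @ np.swapaxes(L, -1, -2)

samples = sample_batch_mvn_from_factor(np.zeros(D), L, 100_000, rng=rng)
print(samples.shape)  # (100000, 4, 2)

# The mean is known to be zero, so the average of y yᵀ estimates the covariance directly
emp_cov = np.einsum('sni,snj->nij', samples, samples) / samples.shape[0]
# error relative to the largest variance of each distribution
scale = np.diagonal(cov, axis1=-2, axis2=-1).max(axis=-1)[:, None, None]
rel_err = np.abs(emp_cov - cov) / scale
print(rel_err.max())
```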
And, as is almost always the case, the plotting takes considerable effort (shown here are some samples for the first 9 of the 5000 covariance matrices):
data:image/s3,"s3://crabby-images/6bf24/6bf245412473f64d09a4a0a13fab460088dd28fa" alt="样本和概率密度函数"
```python
import scipy.stats as stats
import matplotlib.pyplot as plt

fig, axs = plt.subplots(3, 3, figsize=(9, 9))
# zip stops after the 9 axes, so only the first 9 distributions are plotted
for ax, i in zip(axs.ravel(), range(5000)):
    cc = cov[i]
    # first 100 samples of distribution i
    xsamples = samples[:100, i, 0]
    ysamples = samples[:100, i, 1]
    # axis limits: sample range plus 5% padding on each side
    xmin = xsamples.min()
    xmax = xsamples.max()
    ymin = ysamples.min()
    ymax = ysamples.max()
    xpad = (xmax - xmin) * 0.05
    ypad = (ymax - ymin) * 0.05
    xlim = (xmin - xpad, xmax + xpad)
    ylim = (ymin - ypad, ymax + ypad)
    # evaluate the true PDF on a 51 × 51 grid and draw it as filled contours
    xs = np.linspace(*xlim, num=51)
    ys = np.linspace(*ylim, num=51)
    xy = np.dstack(np.meshgrid(xs, ys))
    pdf = stats.multivariate_normal.pdf(xy, mean, cc)
    ax.contourf(xs, ys, pdf, 33, cmap='YlGnBu')
    ax.plot(xsamples, ysamples, 'r.', alpha=.6,
            markeredgecolor='k', markeredgewidth=0.5)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

plt.show()
```
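The plots give a visual check. A numeric alternative, not part of the original answer, relies on a standard fact: for a D-dimensional normal, the squared Mahalanobis distance (y − µ)ᵀΣ⁻¹(y − µ) follows a χ² distribution with D degrees of freedom. Its mean is therefore D and its variance 2D. The sketch below samples with the Cholesky transform and then computes the distance with `np.linalg.solve` on the covariance matrices, not with L, so the check does not simply undo the transform. It uses a seeded generator and smaller sizes, both assumptions made for speed and reproducibility:

```python
import numpy as np

rng = np.random.default_rng(3)
N, D, S = 500, 2, 2000   # smaller than in the answer so the check runs quickly

# Random covariance matrices built as in the answer (cov = F Fᵀ), plus random means
F = np.zeros((N, D, D))
F[(..., *np.tril_indices(D))] = rng.normal(size=(N, D * (D + 1) // 2))
cov = F @ np.swapaxes(F, -1, -2)
mean = rng.normal(size=(N, D))

# Sample with the Cholesky transform Y = L X + mean
L = np.linalg.cholesky(cov)
X = rng.standard_normal((S, N, D, 1))
Y = (L @ X)[..., 0] + mean                       # (S, N, D)

# Squared Mahalanobis distance, solving cov @ sol = diff (cov broadcasts over S)
diff = Y - mean                                  # (S, N, D)
sol = np.linalg.solve(cov, diff[..., None])[..., 0]
d2 = np.einsum('snd,snd->sn', diff, sol)         # (S, N)

# For a chi-square distribution with D degrees of freedom: mean D, variance 2D
print(d2.mean(), d2.var())
```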
Some inspiration for this: