0

假设我们有一个 numpy 数组 v

v=np.array([3, 5])

现在我们使用下面的代码来找到一个新的向量说 w

v1=np.array(range(v[0]+1))
v2=np.array(range(v[1]+1))
w=np.array(list(itertools.product(v1,v2)))

所以 w 看起来像这样,

array([[0, 0],
   [0, 1],
   [0, 2],
   [0, 3],
   [0, 4],
   [0, 5],
   [1, 0],
   [1, 1],
   [1, 2],
   [1, 3],
   [1, 4],
   [1, 5],
   [2, 0],
   [2, 1],
   [2, 2],
   [2, 3],
   [2, 4],
   [2, 5],
   [3, 0],
   [3, 1],
   [3, 2],
   [3, 3],
   [3, 4],
   [3, 5]])

现在,我们需要知道每对中的第一个元素遵循二项式分布 Bin(v[0], 0.1) 并且每对的第二个元素遵循二项式分布 Bin(v [1], 0.05)。一种方法是通过这一个班轮

  import scipy.stats as ss
  prob_vector=np.array(list((ss.binom.pmf(i[0],v[0], 0.1) * ss.binom.pmf(i[1],v[1], 0.05)) for i in w))

输出:

array([5.64086303e-01, 1.48443764e-01, 1.56256594e-02, 8.22403125e-04,
           2.16421875e-05, 2.27812500e-07, 1.88028768e-01, 4.94812547e-02,
           5.20855312e-03, 2.74134375e-04, 7.21406250e-06, 7.59375000e-08,
           2.08920853e-02, 5.49791719e-03, 5.78728125e-04, 3.04593750e-05,
           8.01562500e-07, 8.43750000e-09, 7.73780938e-04, 2.03626563e-04,
           2.14343750e-05, 1.12812500e-06, 2.96875000e-08, 3.12500000e-10])

但是计算需要太多时间,特别是因为我正在迭代几个 v 向量!

有没有一种有效的方法来计算 prob_vector?

谢谢

4

1 回答 1

1

您正在重做很多 pmf 调用,并且在 Python 端而不是 numpy 端做了很多。我们可以通过计算 v1 和 v2 数组来保存这些计算,然后将它们相乘。

import numpy as np
import scipy.stats as ss
import itertools

def orig(x, y):
    v = np.array([x, y])
    v1 =np.array(range(v[0]+1))
    v2=np.array(range(v[1]+1))
    w=np.array(list(itertools.product(v1,v2)))
    prob_vector=np.array(list((ss.binom.pmf(i[0],v[0], 0.1) * ss.binom.pmf(i[1],v[1], 0.05)) for i in w))
    return prob_vector

def faster(x, y):
    b0 = ss.binom.pmf(np.arange(x+1), x, 0.1)
    b1 = ss.binom.pmf(np.arange(y+1), y, 0.05)
    prob_array = b0[:, None] * b1
    prob_vector = prob_array.ravel()
    return prob_vector

这给了我:

In [61]: %timeit orig(3, 5)
4.46 ms ± 82.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [62]: %timeit faster(3, 5)
192 µs ± 4.33 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [63]: %timeit orig(30, 50)
311 ms ± 24.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [64]: %timeit faster(30, 50)
209 µs ± 8.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [65]: (orig(30, 50) == faster(30, 50)).all()
Out[65]: True
于 2018-06-18T21:21:17.367 回答