0

如何在下面编写此代码单元的矢量化版本?代码应该完全矢量化,没有for循环,使用np.meshgrid和np.linspace?

def eval_on_grid_unvectorized(func, extent, numsteps):


     """Evaluates func(x1, x2) for each combination in a 2D grid.

    func: callable - function to evaluate for each grid element

    extent: tuple - grid extent as (x1min, x1max, x2min, x2max)

    numsteps: int - number of grid steps (same for each 
    dimension)

    """
    x1min, x1max, x2min, x2max = extent

    x1 = np.empty((numsteps, numsteps))
    x2 = np.empty((numsteps, numsteps))
    y  = np.empty((numsteps, numsteps))
    for i in range(numsteps):
        for j in range(numsteps):
           x1[i,j] = x1min + j*(x1max-x1min)/(numsteps-1)
           x2[i,j] = x2min + i*(x2max-x2min)/(numsteps-1)
           y[i,j] = func(x1[i,j], x2[i,j])
    return x1, x2, y
4

1 回答 1

0

您的功能,没有y; 生产

In [57]: eval_on_grid_unvectorized(None,(0,5,0,6),6)
Out[57]: 
(array([[0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.]]),
 array([[0. , 0. , 0. , 0. , 0. , 0. ],
        [1.2, 1.2, 1.2, 1.2, 1.2, 1.2],
        [2.4, 2.4, 2.4, 2.4, 2.4, 2.4],
        [3.6, 3.6, 3.6, 3.6, 3.6, 3.6],
        [4.8, 4.8, 4.8, 4.8, 4.8, 4.8],
        [6. , 6. , 6. , 6. , 6. , 6. ]]))

Meshgrid 和 linspace 可以做同样的事情:

In [59]: np.meshgrid(np.linspace(0,5,6), np.linspace(0,6,6),indexing='xy')
Out[59]: 
[array([[0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.],
        [0., 1., 2., 3., 4., 5.]]),
 array([[0. , 0. , 0. , 0. , 0. , 0. ],
        [1.2, 1.2, 1.2, 1.2, 1.2, 1.2],
        [2.4, 2.4, 2.4, 2.4, 2.4, 2.4],
        [3.6, 3.6, 3.6, 3.6, 3.6, 3.6],
        [4.8, 4.8, 4.8, 4.8, 4.8, 4.8],
        [6. , 6. , 6. , 6. , 6. , 6. ]])]

但正如 Jérôme 指出的那样,只要func只能使用标量值,就无法进行快速的整个数组计算,就不能“矢量化”。

如果 func 类似于x+y,那么我们可以简单地将数组传递给它:

In [60]: _[0]+_[1]
Out[60]: 
array([[ 0. ,  1. ,  2. ,  3. ,  4. ,  5. ],
       [ 1.2,  2.2,  3.2,  4.2,  5.2,  6.2],
       [ 2.4,  3.4,  4.4,  5.4,  6.4,  7.4],
       [ 3.6,  4.6,  5.6,  6.6,  7.6,  8.6],
       [ 4.8,  5.8,  6.8,  7.8,  8.8,  9.8],
       [ 6. ,  7. ,  8. ,  9. , 10. , 11. ]])

我们通常调用的关键vectorizationnumpy使用已编译的 numpy 方法和运算符来编写计算。它需要真正的知识numpy; 没有一个神奇的捷径可以让你从标量 Python 计算跳到高效计算numpy

于 2022-01-23T17:19:30.510 回答