5

I'm relatively new to Python and I've got a nested for loop. Since the for loops take a while to run, I'm trying to figure out a way to vectorize this code so it can run faster.

In this case, coord is a 3-dimensional array where coord[x, 0, 0] and coord[x, 0, 1] are integers and coord[x, 0, 2] is either 0 or 1. H is a SciPy sparse matrix and x_dist, y_dist, z_dist, and a are all floats.

# x_dist, y_dist, and z_dist are floats
# coord is a num x 1 x 3 numpy array where num can go into the hundreds of thousands
num = coord.shape[0]    
H = sparse.lil_matrix((num, num))
for i in xrange(num):
    for j in xrange(num):
        if (np.absolute(coord[i, 0, 0] - coord[j, 0, 0]) <= 2 and
                (np.absolute(coord[i, 0, 1] - coord[j, 0, 1]) <= 1)):

            x = ((coord[i, 0, 0] * x_dist + coord[i, 0, 2] * z_dist) -
                 (coord[j, 0, 0] * x_dist + coord[j, 0, 2] * z_dist))

            y = (coord[i, 0, 1] * y_dist) - (coord[j, 0, 1] * y_dist)

            if a - 0.5 <= np.sqrt(x ** 2 + y ** 2) <= a + 0.5:
                H[i, j] = -2.7

I've also read that broadcasting with NumPy, while much faster, can lead to large amounts of memory usage from temporary arrays. Would it be better to go the vectorization route or try and use something like Cython?

4

2 回答 2

7

这就是我将如何矢量化您的代码,稍后将讨论一些警告:

import numpy as np
import scipy.sparse as sps

idx = ((np.abs(coord[:, 0, 0] - coord[:, 0, 0, None]) <= 2) &
       (np.abs(coord[:, 0, 1] - coord[:, 0, 1, None]) <= 1))

rows, cols = np.nonzero(idx)
x = ((coord[rows, 0, 0]-coord[cols, 0, 0]) * x_dist +
     (coord[rows, 0, 2]-coord[cols, 0, 2]) * z_dist)
y = (coord[rows, 0, 1]-coord[cols, 0, 1]) * y_dist
r2 = x*x + y*y

idx = ((a - 0.5)**2 <= r2) & (r2 <= (a + 0.5)**2)

rows, cols = rows[idx], cols[idx]
data = np.repeat(2.7, len(rows))

H = sps.coo_matrix((data, (rows, cols)), shape=(num, num)).tolil()

正如您所指出的,问题将出现在第一个idx数组中,因为它将是形状,因此如果“成百上千” (num, num),它可能会将您的记忆粉碎。num

一种潜在的解决方案是将您的问题分解为可管理的块。如果您有一个 100,000 个元素的数组,您可以将其拆分为 100 个包含 1,000 个元素的块,并为 10,000 个块组合中的每一个运行上面代码的修改版本。您只需要一个 1,000,000 元素idx数组(您可以预先分配和重用它以获得更好的性能),并且您将拥有一个只有 10,000 次迭代的循环,而不是当前实现的 10,000,000,000 次。这是一种穷人的并行化方案,如果你有一台多核机器,你实际上可以通过并行处理其中的几个块来改进它。

于 2013-07-22T16:58:32.580 回答
2

计算的性质使得使用我熟悉的 numpy 方法进行矢量化变得很困难。我认为在速度和内存使用方面的最佳解决方案是 cython。但是,您可以使用numba获得一些加速。这是一个示例(请注意,通常您将autojit其用作装饰器):

import numpy as np
from scipy import sparse
import time
from numba.decorators import autojit
x_dist=.5
y_dist = .5
z_dist = .4
a = .6
coord = np.random.normal(size=(1000,1000,1000))

def run(coord, x_dist,y_dist, z_dist, a):
    num = coord.shape[0]    
    H = sparse.lil_matrix((num, num))
    for i in xrange(num):
        for j in xrange(num):
            if (np.absolute(coord[i, 0, 0] - coord[j, 0, 0]) <= 2 and
                    (np.absolute(coord[i, 0, 1] - coord[j, 0, 1]) <= 1)):

                x = ((coord[i, 0, 0] * x_dist + coord[i, 0, 2] * z_dist) -
                     (coord[j, 0, 0] * x_dist + coord[j, 0, 2] * z_dist))

                y = (coord[i, 0, 1] * y_dist) - (coord[j, 0, 1] * y_dist)

                if a - 0.5 <= np.sqrt(x ** 2 + y ** 2) <= a + 0.5:
                    H[i, j] = -2.7
    return H

runaj = autojit(run)

t0 = time.time()
run(coord,x_dist,y_dist, z_dist, a)
t1 = time.time()
print 'First Original Runtime:', t1 - t0

t0 = time.time()
run(coord,x_dist,y_dist, z_dist, a)
t1 = time.time()
print 'Second Original Runtime:', t1 - t0

t0 = time.time()
run(coord,x_dist,y_dist, z_dist, a)
t1 = time.time()
print 'Third Original Runtime:', t1 - t0

t0 = time.time()
runaj(coord,x_dist,y_dist, z_dist, a)
t1 = time.time()
print 'First Numba Runtime:', t1 - t0

t0 = time.time()
runaj(coord,x_dist,y_dist, z_dist, a)
t1 = time.time()
print 'Second Numba Runtime:', t1 - t0

t0 = time.time()
runaj(coord,x_dist,y_dist, z_dist, a)
t1 = time.time()
print 'Third Numba Runtime:', t1 - t0

我得到这个输出:

First Original Runtime: 21.3574919701
Second Original Runtime: 15.7615520954
Third Original Runtime: 15.3634860516
First Numba Runtime: 9.87108802795
Second Numba Runtime: 9.32944011688
Third Numba Runtime: 9.32300305367
于 2013-07-22T16:54:09.713 回答