0

数值积分所花费的时间比我预期的要长得多。我想知道我在网格上实现迭代的方式是否可能是一个促成因素。我的代码如下所示:

import numpy as np
import itertools as it

U = np.linspace(0, 2*np.pi)
V = np.linspace(0, np.pi)

for (u, v) in it.product(U,V):
    # values = computation on each grid point, does not call any outside functions
    # solution = sum(values)
return solution

我省略了计算,因为它们很长,我的问题特别是关于我在参数空间(u,v)上实现计算的方式。我知道替代方案,例如numpy.meshgrid;然而,这些似乎都创建了(非常大的)矩阵的实例,我猜想将它们存储在内存中会减慢速度。

是否有替代方案it.product可以加快我的程序,或者我应该在别处寻找瓶颈?

编辑:这是有问题的 for 循环(看看它是否可以矢量化)。

import random  
import numpy as np  
import itertools as it 

##########################################################################
# Initialize the inputs with random (to save space)
##########################################################################
mat1 = np.array([[random.random() for i in range(3)] for i in range(3)])
mat2 = np.array([[random.random() for i in range(3)] for i in range(3)]) 
a1, a2, a3 = np.array([random.random() for i in range(3)]) 
plane_normal = np.array([random.random() for i in range(3)])  
plane_point = np.array([random.random() for i in range(3)])  
d = np.dot(plane_normal, plane_point)  
truthval = True

##########################################################################
# Initialize the loop
##########################################################################
N = 100 
U = np.linspace(0, 2*np.pi, N + 1, endpoint = False) 
V = np.linspace(0, np.pi, N + 1, endpoint = False) 
U = U[1:N+1] V = V[1:N+1]

Vsum = 0
Usum = 0

##########################################################################
# The for loops starts here
##########################################################################   
for (u, v) in it.product(U,V):

    cart_point = np.array([a1*np.cos(u)*np.sin(v), 
                           a2*np.sin(u)*np.sin(v), 
                           a3*np.cos(v)])

    surf_normal = np.array(
            [2*x / a**2 for (x, a) in zip(cart_point, [a1,a2,a3])])


    differential_area = \
        np.sqrt((a1*a2*np.cos(v)*np.sin(v))**2 + \
        a3**2*np.sin(v)**4 * \
        ((a2*np.cos(u))**2 + (a1*np.sin(u))**2)) * \
        (np.pi**2 / (2*N**2)) 


    if (np.dot(plane_normal, cart_point) - d > 0) == truthval:
        perp_normal = plane_normal
        f = np.dot(np.dot(mat2, surf_normal), perp_normal)
        Vsum += f*differential_area
    else:
        perp_normal = - plane_normal
        f = np.dot(np.dot(mat2, surf_normal), perp_normal)
        Usum += f*differential_area

integral = abs(Vsum) + abs(Usum)
4

3 回答 3

1

如果U.shape == (nu,)(V.shape == (nv,),则以下数组矢量化您的大部分计算。使用 numpy,您可以通过对最大维度使用数组并在较小维度上循环(例如 3x3)来获得最佳速度。

修正版

A = np.cos(U)[:,None]*np.sin(V)
B = np.sin(U)[:,None]*np.sin(V)
C = np.repeat(np.cos(V)[None,:],U.size,0)
CP = np.dstack([a1*A, a2*B, a3*C])

SN = np.dstack([2*A/a1, 2*B/a2, 2*C/a3])

DA1 = (a1*a2*np.cos(V)*np.sin(V))**2
DA2 = a3*a3*np.sin(V)**4
DA3 = (a2*np.cos(U))**2 + (a1*np.sin(U))**2
DA = DA1 + DA2 * DA3[:,None]
DA = np.sqrt(DA)*(np.pi**2 / (2*Nu*Nv))

D = np.dot(CP, plane_normal)
S = np.sign(D-d)

F1 = np.dot(np.dot(SN, mat2.T), plane_normal)
F = F1 * DA
#F = F * S # apply sign
Vsum = F[S>0].sum()
Usum = F[S<=0].sum()

使用相同的随机值,这会产生相同的值。在 100x100 的情况下,它快 10 倍。一年后玩这些矩阵很有趣。

于 2013-08-11T22:35:54.530 回答
1

在 ipython 中,我对您的 50 x 50 网格空间进行了简单的求和计算

In [31]: sum(u*v for (u,v) in it.product(U,V))
Out[31]: 12337.005501361698

In [33]: UU,VV = np.meshgrid(U,V); sum(sum(UU*VV))
Out[33]: 12337.005501361693

In [34]: timeit UU,VV = np.meshgrid(U,V); sum(sum(UU*VV))
1000 loops, best of 3: 293 us per loop

In [35]: timeit sum(u*v for (u,v) in it.product(U,V)) 
100 loops, best of 3: 2.95 ms per loop

In [38]: timeit list(it.product(U,V))
1000 loops, best of 3: 213 us per loop

In [45]: timeit UU,VV = np.meshgrid(U,V); (UU*VV).sum().sum()
10000 loops, best of 3: 70.3 us per loop
# using numpy's own sum is even better

product更慢(10 倍),不是因为product它本身很慢,而是因为逐点计算。如果您可以对计算进行矢量化,以便它们使用 2 (50,50) 个数组(没有任何循环),它应该会加快整体时间。这是使用numpy.

于 2013-08-11T00:11:22.713 回答
0

[k for k in it.product(U,V)]对我来说运行时间为 2 毫秒,并且 itertool 包非常高效,例如它不会首先创建一个长数组(http://docs.python.org/2/library/itertools.html)。

罪魁祸首似乎是您在迭代中的代码,或者您在 linspace 中使用了很多点。

于 2013-08-10T20:37:07.333 回答