我写了下面这段代码,用来了解多进程(multiprocessing,MP)以及它相对于非 MP 版本可能带来的速度提升。这两个函数几乎相同,只有突出显示的地方不同(抱歉,我不知道有没有更好的方式来突出显示代码区域)。
这段代码尝试找出数组列表(此处为一维数组)中冗余条目的索引。两个函数返回的索引列表是相同的,但我的问题是关于两者的耗时差异。
正如你所看到的,在这两种情况下,我都尝试对 a)map 函数、b)列表扩展(extend)和 c)整个 while 循环分别计时。在 map 部分,MP 版本获得了更好的加速,但 redun_ids.extend(...) 部分却比非 MP 版本慢。这实际上拉低了 MP 版本的整体速度增益。
有什么办法可以重写 MP 版本中的 redun_ids.extend(...) 部分,使其耗时与非 MP 版本相同?
#!/usr/bin/python
import multiprocessing as mproc
import sys
import numpy as np
import random
import time
def matdist(mats):
    """Report whether the first two entries of ``mats`` are element-wise close.

    ``mats[0]`` and ``mats[1]`` are compared with ``np.allclose`` using
    ``rtol=1e-08`` and ``atol=1e-12``; the result of that call is returned.
    """
    left, right = mats[0], mats[1]
    tolerances = {'rtol': 1e-08, 'atol': 1e-12}
    return np.allclose(left, right, **tolerances)
def mp_remove_redundancy(larrays, processes=10):
    """Identify redundant arrays in ``larrays`` using a process pool.

    Every array that has not already been marked redundant is compared
    (through ``matdist``) against each array that follows it in the list;
    the positions of the later arrays that match are recorded.

    Args:
        larrays: sequence of arrays to scan.
        processes: number of worker processes in the pool (a positive
            integer). It also sizes the chunks handed to the workers.
            Defaults to 10, the pool size this function always used.

    Returns:
        A list of indices into ``larrays``, in the order they were found.
        A list with fewer than two entries yields an empty list.

    Side effects:
        Prints per-iteration and total timings to stdout.
    """
    llen = len(larrays)
    if llen < 2:
        # Nothing to compare against; also avoids indexing past the end.
        return []

    redun_ids = []
    # Mirror of redun_ids that gives O(1) membership tests in the loop.
    flagged = set()

    # MP-specific: one worker pool shared by all iterations.
    pool = mproc.Pool(processes=processes)
    try:
        st1 = time.time()
        for i in range(llen - 1):
            if i in flagged:
                continue
            currarray = larrays[i]
            # matdist receives (later array, current array), in that order.
            pairs = [(other, currarray) for other in larrays[i + 1:]]
            # MP-specific: split the work roughly evenly, at least 1 per chunk.
            chunksize = max(1, len(pairs) // processes)
            st = time.time()
            # MP-specific: map_async only queues the work and get() blocks
            # until the results arrive, so both are inside the timed region.
            outlist = pool.map_async(matdist, pairs, chunksize).get()
            # Single-argument print() emits the same text under Python 2's
            # print statement and Python 3's print function.
            print('map time: %s' % (time.time() - st))
            st = time.time()
            # j + 1 + i converts a position in outlist to an index in larrays.
            new_ids = [j + 1 + i for j, x in enumerate(outlist) if x]
            redun_ids.extend(new_ids)
            print('Redun ids extend time: %s' % (time.time() - st))
            flagged.update(new_ids)
        print('Time elapsed in MP: %s' % (time.time() - st1))
    finally:
        # Always release the workers, even when a comparison raised.
        pool.close()
        pool.join()
    return redun_ids
#######################################################
def remove_redundancy(larrays):
    """Identify redundant arrays in ``larrays`` (single-process version).

    Every array that has not already been marked redundant is compared
    (through ``matdist``) against each array that follows it in the list;
    the positions of the later arrays that match are recorded.

    Args:
        larrays: sequence of arrays to scan.

    Returns:
        A list of indices into ``larrays``, in the order they were found.
        A list with fewer than two entries yields an empty list.

    Side effects:
        Prints per-iteration and total timings to stdout.
    """
    llen = len(larrays)
    if llen < 2:
        # Nothing to compare against; also avoids indexing past the end.
        return []

    redun_ids = []
    # Mirror of redun_ids that gives O(1) membership tests in the loop.
    flagged = set()

    st1 = time.time()
    for i in range(llen - 1):
        if i in flagged:
            continue
        currarray = larrays[i]
        st = time.time()
        # matdist receives (later array, current array), in that order.
        clslist = [matdist((other, currarray)) for other in larrays[i + 1:]]
        # Single-argument print() emits the same text under Python 2's
        # print statement and Python 3's print function.
        print('map time: %s' % (time.time() - st))
        st = time.time()
        # j + 1 + i converts a position in clslist to an index in larrays.
        new_ids = [j + 1 + i for j, x in enumerate(clslist) if x]
        redun_ids.extend(new_ids)
        print('Redun ids extend time: %s' % (time.time() - st))
        flagged.update(new_ids)
    print('Tot non MP time: %s' % (time.time() - st1))
    return redun_ids
###################################################################
if __name__ == '__main__':
    # Exactly one command-line argument is expected: the number of entries.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        sys.exit('# entries')
    llen = int(cli_args[0])
    # One single-element array per entry, each holding a rounded random
    # value between 1 and 10.
    mylist = []
    for _ in range(llen):
        mylist.append(np.array([round(random.random() * 9 + 1)]))
    print('The input list')
    print('no MP')
    rrlist = remove_redundancy(mylist)
    print('MP')
    rrmplist = mp_remove_redundancy(mylist)
    lists_agree = rrlist == rrmplist
    print('Two lists match : {0}'.format(lists_agree))