12

I'm calling numpy's where function many times inside a couple of for loops, but it has become far too slow. Is there any way to speed it up? I read that you should try inlining the for loops and creating local variables for the functions before the for loop, but nothing seems to improve the speed (< 1%). len(UNIQ_IDS) ~ 800, and emiss_data and obj_data are numpy ndarrays of shape (2600, 5200). I used import profile to get a handle on where the bottlenecks are, and the where call in the for loop is a big one.

import numpy as np
max = np.max
where = np.where
MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS]

6 Answers

10

It turns out that a pure Python loop can be much faster than NumPy indexing (or calls to np.where) in this case.

Consider the following alternatives:

import numpy as np
import collections
import itertools as IT

shape = (2600,5200)
# shape = (26,52)
emiss_data = np.random.random(shape)
obj_data = np.random.random_integers(1, 800, size=shape)
UNIQ_IDS = np.unique(obj_data)

def using_where():
    max = np.max
    where = np.where
    MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS]
    return MAX_EMISS

def using_index():
    max = np.max
    MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS]
    return MAX_EMISS

def using_max():
    MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS]
    return MAX_EMISS

def using_loop():
    result = collections.defaultdict(list)
    for val, idx in IT.izip(emiss_data.ravel(), obj_data.ravel()):
        result[idx].append(val)
    return [max(result[idx]) for idx in UNIQ_IDS]

def using_sort():
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    start = 0
    end = 0
    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out

def using_split():
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    return [np.take(emiss_data, item).max()
            for item in np.split(vals, count.cumsum())[:-1]]

for func in (using_index, using_max, using_loop, using_sort, using_split):
    assert np.allclose(using_where(), func())

Here are the benchmarks, with shape = (2600, 5200):

In [57]: %timeit using_loop()
1 loops, best of 3: 9.15 s per loop

In [90]: %timeit using_sort()
1 loops, best of 3: 9.33 s per loop

In [91]: %timeit using_split()
1 loops, best of 3: 9.33 s per loop

In [61]: %timeit using_index()
1 loops, best of 3: 63.2 s per loop

In [62]: %timeit using_max()
1 loops, best of 3: 64.4 s per loop

In [58]: %timeit using_where()
1 loops, best of 3: 112 s per loop

So using_loop (pure Python) turns out to be more than 10 times faster than using_where.

I'm not entirely sure why pure Python is faster than NumPy here. My guess is that the pure Python version zips through both arrays once (yes, pun intended). It takes advantage of the fact that, despite all the fancy indexing, we really only want to visit each value once. So it sidesteps the problem of having to determine exactly which group each value in emiss_data belongs to. But this is just vague speculation; I didn't know it would be faster until I benchmarked it.
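
If that guess is right, the loop can be pushed one step further: keep a running maximum per id instead of accumulating full lists, so each value is visited once and almost nothing is stored. A minimal sketch (my own variant, not part of the benchmarks above; it reuses the setup from the listing):

def using_runningmax():
    # Single pass over both flattened arrays, keeping only a running
    # max per id instead of building per-id lists first.
    result = {}
    for val, idx in IT.izip(emiss_data.ravel(), obj_data.ravel()):
        prev = result.get(idx)
        if prev is None or val > prev:
            result[idx] = val
    return [result[idx] for idx in UNIQ_IDS]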

answered 2013-08-26T21:15:45.800
8

You can use np.unique with return_inverse=True:

def using_sort():
    #UNIQ_IDS, uind = np.unique(obj_data, return_inverse=True)
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)

    start = 0
    end = 0

    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out
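
Spelled out, the commented-out np.unique variant looks like this (a sketch; the name using_sort_unique is mine). The "inside vs. outside the definition" timings below refer to whether this np.unique call runs inside the function or is hoisted out of it:

def using_sort_unique():
    # np.unique supplies both the unique ids and each element's group
    # index directly, replacing the np.digitize step above.
    UNIQ_IDS, uind = np.unique(obj_data, return_inverse=True)
    uind = uind.ravel()  # ensure a flat index array, matching obj_data.ravel()
    vals = uind.argsort()
    count = np.bincount(uind)
    out = np.empty(count.shape[0])
    start = 0
    end = 0
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out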

Using @unutbu's answer as a baseline, with shape = (2600, 5200):

np.allclose(using_loop(),using_sort())
True

%timeit using_loop()
1 loops, best of 3: 12.3 s per loop

#With np.unique inside the definition
%timeit using_sort()
1 loops, best of 3: 9.06 s per loop

#With np.unique outside the definition 
%timeit using_sort()
1 loops, best of 3: 2.75 s per loop

#Using @Jamie's suggestion for uind
%timeit using_sort()
1 loops, best of 3: 6.74 s per loop
answered 2013-08-26T21:49:31.377
5

I believe the fastest way to accomplish this is to use the groupby() operations in the pandas package. Compared with @Ophion's using_sort() function, pandas is faster by roughly a factor of 10:

import numpy as np
import pandas as pd

shape = (2600,5200)
emiss_data = np.random.random(shape)
obj_data = np.random.random_integers(1, 800, size=shape)
UNIQ_IDS = np.unique(obj_data)

def using_sort():
    #UNIQ_IDS, uind = np.unique(obj_data, return_inverse=True)
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)

    start = 0
    end = 0

    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out

def using_pandas():
    return pd.Series(emiss_data.ravel()).groupby(obj_data.ravel()).max()

print('same results:', np.allclose(using_pandas(), using_sort()))
# same results: True

%timeit using_sort()
# 1 loops, best of 3: 3.39 s per loop

%timeit using_pandas()
# 1 loops, best of 3: 397 ms per loop
answered 2015-10-24T17:56:00.327
3

Can't you just do

emiss_data[obj_data == i]

? I'm not sure why you're using where at all.
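
In full, that is just the boolean-mask indexing used by using_index/using_max in the accepted answer:

MAX_EMISS = [emiss_data[obj_data == i].max() for i in UNIQ_IDS]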

answered 2013-08-26T20:47:40.237
0

Assigning a tuple is much faster than assigning a list, according to Are tuples more efficient than lists in Python?, so perhaps simply building a tuple instead of a list will improve efficiency.
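
For example, the tuple version is a one-line change (a sketch; since the per-id masking dominates the cost here, any gain is likely marginal):

MAX_EMISS = tuple(emiss_data[obj_data == i].max() for i in UNIQ_IDS)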

answered 2013-08-26T21:01:16.343
0

If obj_data consists of relatively small integers, you can use numpy.maximum.at (available since v1.8.0):

def using_maximumat():
    n = np.max(UNIQ_IDS) + 1
    temp = np.full(n, -np.inf)
    np.maximum.at(temp, obj_data, emiss_data)
    return temp[UNIQ_IDS]
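
A quick sanity check against the plain boolean-mask version from the question (assuming the same emiss_data, obj_data, and UNIQ_IDS arrays):

expected = [emiss_data[obj_data == i].max() for i in UNIQ_IDS]
assert np.allclose(using_maximumat(), expected)
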
answered 2015-10-25T15:32:29.333