在 numpy.xml 中使用一些不同的技巧可能有一种更快的方法。numpy.indices
是你想开始的地方。itertools.product
一旦你将它与 结合起来,它本质上相当于rollaxis
。Sven Marnach在这个问题中的回答就是一个很好的例子(然而,他的最后一个例子中有一个小错误,这是你想要使用的。应该是numpy.indices((len(some_list) + 1), * some_length...
)
但是,对于这样的事情,使用 itertools 可能更具可读性。
numpy.fromiter
比将事物显式转换为列表要快一点,特别是如果你给它一个迭代器中项目数的计数。主要优点是使用的内存要少得多,这在大型迭代器的情况下非常有用。
以下是一些使用numpy.indices
技巧和各种将迭代器转换为 numpy 数组的方法的示例:
import itertools
import numpy as np
import scipy.special
def fixed_total_product(bins, num_items):
return itertools.ifilter(lambda combo: sum(combo) == num_items,
itertools.product(xrange(num_items + 1), repeat=bins))
def fixed_total_product_fromiter(bins, num_items):
size = scipy.special.binom(bins - 1 + num_items, num_items)
combinations = fixed_total_product(bins, num_items)
indicies = (idx for row in combinations for idx in row)
arr = np.fromiter(indicies, count=bins * int(size), dtype=np.int)
return arr.reshape(-1, bins)
def fixed_total_product_fromlist(bins, num_items):
return np.array(list(fixed_total_product(bins, num_items)), dtype=np.int)
def fixed_total_product_numpy(bins, num_items):
arr = np.rollaxis(np.indices((num_items + 1,) * bins), 0, bins + 1)
arr = arr.reshape(-1, bins)
arr = np.arange(num_items + 1)[arr]
sums = arr.sum(axis=1)
return arr[sums == num_items]
m, n = 5, 20
if __name__ == '__main__':
import timeit
list_time = timeit.timeit('fixed_total_product_fromlist(m, n)',
setup='from __main__ import fixed_total_product_fromlist, m, n',
number=1)
iter_time = timeit.timeit('fixed_total_product_fromiter(m, n)',
setup='from __main__ import fixed_total_product_fromiter, m, n',
number=1)
numpy_time = timeit.timeit('fixed_total_product_numpy(m, n)',
setup='from __main__ import fixed_total_product_numpy, m, n',
number=1)
print 'All combinations of {0} items drawn from a set of {1} items...'.format(m,n)
print ' Converting to a list and then an array: {0} sec'.format(list_time)
print ' Using fromiter: {0} sec'.format(iter_time)
print ' Using numpy.indices: {0} sec'.format(numpy_time)
至于时机:
All combinations of 5 items drawn from a set of 20 items...
Converting to a list and then an array: 2.75901389122 sec
Using fromiter: 2.10619592667 sec
Using numpy.indices: 1.44955015182 sec
您会注意到它们都不是特别快。
您正在使用蛮力算法(生成所有可能的组合,然后过滤它们),我只是在此处基于 numpy 的示例中复制它。
可能有一种更有效的方法来做到这一点!但是,我无论如何都不是 CS 或数学人,所以我不知道是否有一个众所周知的算法可以在不首先生成所有可能的组合的情况下做到这一点......
无论如何,祝你好运!