1

我正在尝试优化以下代码,可能通过在 Cython 中重写它:它只需要一个低维但相对较长的 numpy 数组,查看其列中的 0 值,并将它们标记为数组中的 -1。代码是:

import numpy as np

def get_data():
    data = np.array([[1,5,1]] * 5000 + [[1,0,5]] * 5000 + [[0,0,0]] * 5000)
    return data

def get_cols(K):
    cols = np.array([2] * K)
    return cols

def test_nonzero(data):
    K = len(data)
    result = np.array([1] * K)
    # Index into columns of data
    cols = get_cols(K)
    # Mark zero points with -1
    idx = np.nonzero(data[np.arange(K), cols] == 0)[0]
    result[idx] = -1

import time
t_start = time.time()
data = get_data()
for n in range(5000):
    test_nonzero(data)
t_end = time.time()
print (t_end - t_start)

data是数据。cols是用于查找非零值的数据列数组(为简单起见,我将其全部设为同一列)。目标是计算一个 numpy 数组 ,result其中感兴趣的列非零的每一行的值为 1,感兴趣的相应列为零的行的值为 -1。

在 15,000 行乘 3 列的不太大的数组上运行此函数 5000 次大约需要 20 秒。有没有办法可以加快速度?似乎大部分工作都用于查找非零元素并使用索引检索它们(对其索引的调用nonzero和后续使用)。这可以优化还是可以做到最好?Cython 实施如何在这方面提高速度?

4

1 回答 1

2
cols = np.array([2] * K)

那会很慢。那是创建一个非常大的 python 列表,然后将其转换为一个 numpy 数组。相反,请执行以下操作:

cols = np.ones(K, int)*2

这样会快很多

result = np.array([1] * K)

在这里你应该这样做:

result = np.ones(K, int)

这将直接生成 numpy 数组。

idx = np.nonzero(data[np.arange(K), cols] == 0)[0]
result[idx] = -1

cols 是一个数组,但您可以只传递一个 2。此外,使用非零值会增加一个额外的步骤。

idx = data[np.arange(K), 2] == 0
result[idx] = -1

应该有同样的效果。

于 2013-04-16T04:13:35.800 回答