我正在尝试优化以下代码,可能通过在 Cython 中重写它:它只需要一个低维但相对较长的 numpy 数组,查看其列中的 0 值,并将它们标记为数组中的 -1。代码是:
import numpy as np
def get_data():
data = np.array([[1,5,1]] * 5000 + [[1,0,5]] * 5000 + [[0,0,0]] * 5000)
return data
def get_cols(K):
cols = np.array([2] * K)
return cols
def test_nonzero(data):
K = len(data)
result = np.array([1] * K)
# Index into columns of data
cols = get_cols(K)
# Mark zero points with -1
idx = np.nonzero(data[np.arange(K), cols] == 0)[0]
result[idx] = -1
import time
t_start = time.time()
data = get_data()
for n in range(5000):
test_nonzero(data)
t_end = time.time()
print (t_end - t_start)
data
是数据。cols
是用于查找非零值的数据列数组(为简单起见,我将其全部设为同一列)。目标是计算一个 numpy 数组 ,result
其中感兴趣的列非零的每一行的值为 1,感兴趣的相应列为零的行的值为 -1。
在 15,000 行乘 3 列的不太大的数组上运行此函数 5000 次大约需要 20 秒。有没有办法可以加快速度?似乎大部分工作都用于查找非零元素并使用索引检索它们(对其索引的调用nonzero
和后续使用)。这可以优化还是可以做到最好?Cython 实施如何在这方面提高速度?