2

我正在对数据进行一维小波变换。我怎样才能让它更快?我有 140 万个样本和 32 个特征。

def apply_wavelet_transform(data):
    ca,cd=pywt.dwt(data[0,:],'haar')
    for i in range(1,data.shape[0]):
        ca_i,__=pywt.dwt(data[i,:],'haar')
        ca=np.vstack((ca,ca_i))
    return ca

考虑到我不关心内存使用和执行速度。

4

1 回答 1

4

这是一个常见的错误。您不想一次将行追加到数组中,因为每次迭代都需要复制整个数组。复杂度:O(N**2)。将中间结果保存在列表中并在最后形成数组要好得多。这更好,因为列表不需要它们的元素在内存中是连续的,因此不需要复制。

def apply_wavelet_transform(data):
    results_list = []
    for row in data:
        ca, cd = pywt.dwt(row, 'haar')
        results_list.append(ca)
    result = np.array(results_list)
    return result
于 2015-08-17T03:34:48.733 回答