0

我正在尝试使用花哨的索引而不是循环来加速 Numpy 中的函数。据我所知,我已经正确实现了精美的索引版本。问题是这两个函数(循环和花式索引)不返回相同的结果。我不确定为什么。值得指出的是,如果使用较小的数组(例如,20 x 20 x 20),这些函数会返回相同的结果。

下面我已经包含了重现错误所需的所有内容。如果函数确实返回相同的结果,则该行find_maxdiff(data) - find_maxdiff_fancy(data)应返回一个全零的数组。

from numpy import *

def rms(data, axis=0):
    return sqrt(mean(data ** 2, axis))

def find_maxdiff(data):
    samples, channels, epochs = shape(data)
    window_size = 50
    maxdiff = zeros(epochs)
    for epoch in xrange(epochs):
        signal = rms(data[:, :, epoch], axis=1)
        for t in xrange(window_size, alen(signal) - window_size):
            amp_a = mean(signal[t-window_size:t], axis=0)
            amp_b = mean(signal[t:t+window_size], axis=0)
            the_diff = abs(amp_b - amp_a)
            if the_diff > maxdiff[epoch]: 
                maxdiff[epoch] = the_diff

    return maxdiff

def find_maxdiff_fancy(data):
    samples, channels, epochs = shape(data)
    window_size = 50
    maxdiff = zeros(epochs)
    signal = rms(data, axis=1)
    for t in xrange(window_size, alen(signal) - window_size):
        amp_a = mean(signal[t-window_size:t], axis=0)
        amp_b = mean(signal[t:t+window_size], axis=0)
        the_diff = abs(amp_b - amp_a)
        maxdiff[the_diff > maxdiff] = the_diff

    return maxdiff

data = random.random((600, 20, 100))
find_maxdiff(data) - find_maxdiff_fancy(data)

data = random.random((20, 20, 20))
find_maxdiff(data) - find_maxdiff_fancy(data)
4

2 回答 2

3

问题是这一行:

maxdiff[the_diff > maxdiff] = the_diff

左边只选择了maxdiff的部分元素,而右边包含了the_diff的所有元素。这应该起作用:

replaceElements = the_diff > maxdiff
maxdiff[replaceElements] = the_diff[replaceElements]

或者简单地说:

maxdiff = maximum(maxdiff, the_diff)

至于为什么 20x20x20 尺寸似乎有效:这是因为您的窗口尺寸太大,所以没有执行任何操作。

于 2009-11-23T13:47:47.673 回答
0

首先,如果我理解正确,您的信号现在是二维的 - 所以我认为明确索引它会更清楚(例如 amp_a = mean(signal[t-window_size:t,:], axis=0)。与 alen 类似(信号) - 这应该只是两种情况下的样本,所以我认为使用它会更清楚。

每当您实际上在t循环中做某事时,这是错误的——samples < window_lenght就像在 20x20x20 示例中那样,该循环永远不会被执行。一旦该循环多次执行(即samples > 2 *window_length+1),错误就会出现。不知道为什么——它们看起来确实和我一样。

于 2009-11-23T10:44:09.230 回答