21

我正在尝试改进为图像的每个像素计算位于像素附近的像素的标准偏差的函数。我的函数使用两个嵌入式循环在矩阵上运行,这是我的程序的瓶颈。我想可能有一种方法可以通过摆脱循环来改进它,这要归功于 numpy,但我不知道如何继续。欢迎任何建议!

问候

def sliding_std_dev(image_original,radius=5) :
    height, width = image_original.shape
    result = np.zeros_like(image_original) # initialize the output matrix
    hgt = range(radius,height-radius)
    wdt = range(radius,width-radius)
    for i in hgt:
        for j in wdt:
            result[i,j] = np.std(image_original[i-radius:i+radius,j-radius:j+radius])
    return result
4

5 回答 5

32

很酷的技巧:您可以仅给定平方值的总和和窗口中值的总和来计算标准偏差。

因此,您可以使用数据上的统一过滤器非常快速地计算标准偏差:

from scipy.ndimage.filters import uniform_filter

def window_stdev(arr, radius):
    c1 = uniform_filter(arr, radius*2, mode='constant', origin=-radius)
    c2 = uniform_filter(arr*arr, radius*2, mode='constant', origin=-radius)
    return ((c2 - c1*c1)**.5)[:-radius*2+1,:-radius*2+1]

这比原始功能快得离谱。对于 1024x1024 的数组,半径为 20,旧函数耗时 34.11 秒,新函数耗时0.11 秒,速度提升了 300 倍。


这在数学上是如何工作的?它计算sqrt(mean(x^2) - mean(x)^2)每个窗口的数量。我们可以从标准差中得出这个量,sqrt(mean((x - mean(x))^2))如下所示:

E是期望算子(基本上是mean()),并且X是数据的随机变量。然后:

E[(X - E[X])^2]
= E[X^2 - 2X*E[X] + E[X]^2]
= E[X^2] - E[2X*E[X]] + E[E[X]^2](通过期望算子的线性) (再次通过线性,以及常数
= E[X^2] - 2E[X]*E[X] + E[X]^2的事实)E[X]
= E[X^2] - E[X]^2

这证明了使用这种技术计算的数量在数学上等同于标准偏差。

于 2013-08-24T19:50:07.340 回答
13

在图像处理中最常用的方法是使用求和面积表,这是 1984 年在这篇论文中提出的一个想法。这个想法是,当你通过在窗口上加法来计算数量时,然后移动窗口,例如右边一个像素,不需要把新窗口中的所有项都加进去,只需要从总数中减去最左边的那一列,再加上新的最右边那一列。因此,如果您从数组的两个维度上创建一个累加和数组,您可以在一个窗口上通过几个和和一个减法获得总和。如果您为数组及其正方形保留总面积表,则很容易从这两个中获得方差。这是一个实现:

def windowed_sum(a, win):
    table = np.cumsum(np.cumsum(a, axis=0), axis=1)
    win_sum = np.empty(tuple(np.subtract(a.shape, win-1)))
    win_sum[0,0] = table[win-1, win-1]
    win_sum[0, 1:] = table[win-1, win:] - table[win-1, :-win]
    win_sum[1:, 0] = table[win:, win-1] - table[:-win, win-1]
    win_sum[1:, 1:] = (table[win:, win:] + table[:-win, :-win] -
                       table[win:, :-win] - table[:-win, win:])
    return win_sum

def windowed_var(a, win):
    win_a = windowed_sum(a, win)
    win_a2 = windowed_sum(a*a, win)
    return (win_a2 - win_a * win_a / win/ win) / win / win

要查看这是否有效:

>>> a = np.arange(25).reshape(5,5)
>>> windowed_var(a, 3)
array([[ 17.33333333,  17.33333333,  17.33333333],
       [ 17.33333333,  17.33333333,  17.33333333],
       [ 17.33333333,  17.33333333,  17.33333333]])
>>> np.var(a[:3, :3])
17.333333333333332
>>> np.var(a[-3:, -3:])
17.333333333333332

这应该比基于卷积的方法快几个缺口。

于 2013-08-24T22:36:44.053 回答
3

首先,有不止一种方法可以做到这一点。

这不是最有效的速度,但使用scipy.ndimage.generic_filter将允许您轻松地在移动窗口上应用任意 python 函数。

举个简单的例子:

result = scipy.ndimage.generic_filter(data, np.std, size=2*radius)

请注意,边界条件可以由modekwarg 控制。


另一种方法是使用一些不同的跨步技巧来查看实际上是移动窗口的数组视图,然后np.std沿最后一个轴应用。(注意:这取自我以前的答案之一:https ://stackoverflow.com/a/4947453/325565 )

def strided_sliding_std_dev(data, radius=5):
    windowed = rolling_window(data, (2*radius, 2*radius))
    shape = windowed.shape
    windowed = windowed.reshape(shape[0], shape[1], -1)
    return windowed.std(axis=-1)

def rolling_window(a, window):
    """Takes a numpy array *a* and a sequence of (or single) *window* lengths
    and returns a view of *a* that represents a moving window."""
    if not hasattr(window, '__iter__'):
        return rolling_window_lastaxis(a, window)
    for i, win in enumerate(window):
        if win > 1:
            a = a.swapaxes(i, -1)
            a = rolling_window_lastaxis(a, win)
            a = a.swapaxes(-2, i)
    return a

def rolling_window_lastaxis(a, window):
    """Directly taken from Erik Rigtorp's post to numpy-discussion.
    <http://www.mail-archive.com/numpy-discussion@scipy.org/msg29450.html>"""
    if window < 1:
       raise ValueError, "`window` must be at least 1."
    if window > a.shape[-1]:
       raise ValueError, "`window` is too long."
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

乍一看有点难以理解这里发生了什么。不要插入我自己的答案之一,但我不想重新输入解释,所以看看这里:https ://stackoverflow.com/a/4924433/325565如果你还没有看到这些之前的“大步”技巧。

如果我们将时序与一个 100x100 的随机浮点数组与 aradius为 5 进行比较,它比原始版本或generic_filter版本快约 10 倍。但是,您在此版本的边界条件上没有灵活性。(它与您当前所做的相同,而该generic_filter版本以牺牲速度为代价为您提供了很大的灵活性。)

# Your original function with nested loops
In [21]: %timeit sliding_std_dev(data)
1 loops, best of 3: 237 ms per loop

# Using scipy.ndimage.generic_filter
In [22]: %timeit ndimage_std_dev(data)
1 loops, best of 3: 244 ms per loop

# The "stride-tricks" version above
In [23]: %timeit strided_sliding_std_dev(data)
100 loops, best of 3: 15.4 ms per loop

# Ophion's version that uses `np.take`
In [24]: %timeit new_std_dev(data)
100 loops, best of 3: 19.3 ms per loop

“stride-tricks”版本的缺点是,与“正常”的跨步滚动窗口技巧不同,这个版本确实会复制,并且比原始数组大得多如果您在大型阵列上使用它,您遇到内存问题!(附带说明一下,在内存使用和速度方面,它基本上等同于@Ophion 的答案。这只是做同样事情的不同方法。)

于 2013-08-24T16:36:19.647 回答
1

您可以先获取索引,然后使用它np.take来形成新数组:

def new_std_dev(image_original,radius=5):
    cols,rows=image_original.shape

    #First obtain the indices for the top left position
    diameter=np.arange(radius*2)
    x,y=np.meshgrid(diameter,diameter)
    index=np.ravel_multi_index((y,x),(cols,rows)).ravel()

    #Cast this in two dimesions and take the stdev
    index=index+np.arange(rows-radius*2)[:,None]+np.arange(cols-radius*2)[:,None,None]*(rows)
    data=np.std(np.take(image_original,index),-1)

    #Add the zeros back to the output array
    top=np.zeros((radius,rows-radius*2))
    sides=np.zeros((cols,radius))

    data=np.vstack((top,data,top))
    data=np.hstack((sides,data,sides))
    return data

首先生成一些随机数据并检查时序:

a=np.random.rand(50,20)

print np.allclose(new_std_dev(a),sliding_std_dev(a))
True

%timeit sliding_std_dev(a)
100 loops, best of 3: 18 ms per loop

%timeit new_std_dev(a)
1000 loops, best of 3: 472 us per loop

对于较大的数组,只要您有足够的内存,它总是更快:

a=np.random.rand(200,200)

print np.allclose(new_std_dev(a),sliding_std_dev(a))
True

%timeit sliding_std_dev(a)
1 loops, best of 3: 1.58 s per loop

%timeit new_std_dev(a)
10 loops, best of 3: 52.3 ms per loop

对于非常小的数组,原始函数更快,看起来盈亏平衡点是 when hgt*wdt >50。需要注意的是,您的函数采用方形框架并将 std dev 放在右下角的索引中,而不是在索引周围采样。这是故意的吗?

于 2013-08-24T16:44:12.840 回答
0

在尝试在这里使用几个优秀的解决方案之后,我在处理包含 NaN 的数据时遇到了麻烦。uniform_filter和解决方案都np.cumsum()导致 Nan's 通过输出数组传播,而不是被忽略。

我下面的解决方案基本上只是将@Jaime 答案中的加窗求和函数与卷积交换,这对 NaN 是稳健的。

def windowed_sum(arr: np.ndarray, radius: int) -> np.ndarray:
    """radius=1 means the pixel itself and the 8 surrounding pixels"""

    kernel = np.ones((radius * 2 + 1, radius * 2 + 1), dtype=int)
    return convolve(arr, kernel, mode="constant", cval=0.0)

def windowed_var(arr: np.ndarray, radius: int) -> np.ndarray:
    """Note: this returns smaller in size than the input array (by radius)"""

    diameter = radius * 2 + 1
    win_sum = windowed_sum(arr, radius)[radius:-radius, radius:-radius]
    win_sum_2 = windowed_sum(arr * arr, radius)[radius:-radius, radius:-radius]
    return (win_sum_2 - win_sum * win_sum / diameter / diameter) / diameter / diameter

def windowed_std(arr: np.ndarray, radius: int) -> np.ndarray:

    output = np.full_like(arr, np.nan, dtype=np.float64)

    var_arr = windowed_var(arr, radius)
    std_arr = np.sqrt(var_arr)
    output[radius:-radius, radius:-radius] = std_arr

    return output

这比 执行得慢一点uniform_filter,但仍然比许多其他方法(堆叠数组、迭代等)快得多

>>> data = np.random.random((1024, 1024))
>>> %timeit windowed_std(data, 4)
158 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

与之相比,uniform_filter对于相同大小的数据执行大约 36 毫秒

有一些 NaN 的:

data = np.arange(100, dtype=np.float64).reshape(10, 10)
data[3:4, 3:4] = np.nan
windowed_std(data, 1)

array([[ nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21,  nan,  nan,  nan, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21,  nan,  nan,  nan, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21,  nan,  nan,  nan, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21, 8.21,  nan],
       [ nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan,  nan]])
于 2022-02-16T03:05:54.517 回答