4

这是我第一次尝试在 numpy 中使用 strides,与在不同过滤器上的简单迭代相比,它确实提高了速度,但它仍然很慢(感觉至少有一两件事是完全冗余或低效的) .

所以我的问题是:是否有更好的方法来执行此操作或对我的代码进行调整以使其显着更快?

该算法对每个像素执行 9 个不同过滤器的局部评估,并选择具有最小标准偏差的过滤器(我尝试实现 Nagau 和 Matsuyma (1980) “复杂区域照片的结构分析”,如图像分析中所述书)。结果是平滑和边缘锐化的图像(如果你问我,这很酷!)

import numpy as np
from scipy import ndimage
from numpy.lib import stride_tricks

def get_rotating_kernels():

    kernels = list()

    protokernel = np.arange(9).reshape(3,  3)

    for k in xrange(9):

        ax1, ax2 = np.where(protokernel==k)
        kernel = np.zeros((5,5), dtype=bool)
        kernel[ax1: ax1+3, ax2: ax2+3] = 1
        kernels.append(kernel)

    return kernels


def get_rotation_smooth(im, **kwargs):

    kernels = np.array([k.ravel() for k in get_rotating_kernels()],
                dtype=bool)

    def rotation_matrix(section):

        multi_s = stride_tricks.as_strided(section, shape=(9,25),
            strides=(0, section.itemsize))

        rot_filters = multi_s[kernels].reshape(9,9)

        return rot_filters[rot_filters.std(1).argmin(),:].mean()

    return ndimage.filters.generic_filter(im, rotation_matrix, size=5, **kwargs)

from scipy import lena
im = lena()
im2 = get_rotation_smooth(im)

(只是评论,get_rotating_kernel还没有真正优化,因为几乎没有时间花在那里)

在我的上网本上,它花了 126 秒,而莉娜毕竟是一个很小的图像。

编辑:

我得到了更改rot_filters.std(1)rot_filters.var(1)以保存相当多的平方根的建议,并且它以 5 秒的顺序剃掉了一些东西。

4

2 回答 2

1

我相信您将很难使用 Python + 进行显着优化scipy。但是,我能够通过使用直接as_strided生成rot_filters而不是通过布尔索引进行小幅改进。这是基于一个非常简单的 n 维windows函数。(我意识到 2d 卷积函数存在于scipy. 有关其工作原理的说明,请参见下文:

import numpy as np
from scipy import ndimage
from numpy.lib import stride_tricks

# pass in `as_strided` as a default arg to save a global lookup
def rotation_matrix2(section, _as_strided=stride_tricks.as_strided):
    section = section.reshape(5, 5)  # sqrt(section.size), sqrt(section.size)
    windows_shape = (3, 3, 3, 3)     # 5 - 3 + 1, 5 - 3 + 1, 3, 3
    windows_strides = section.strides + section.strides
    windows = _as_strided(section, windows_shape, windows_strides)
    rot_filters = windows.reshape(9, 9)
    return rot_filters[rot_filters.std(1).argmin(),:].mean()

def get_rotation_smooth(im, _rm=rotation_matrix2, **kwargs):
    return ndimage.filters.generic_filter(im, _rm, size=5, **kwargs)

if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from scipy.misc import lena
    im = lena()
    im2 = get_rotation_smooth(im)
    #plt.gray()      # Uncomment these lines for
    #plt.imshow(im2) # demo purposes.
    #plt.show()

上面的函数rotation_matrix2等价于以下两个函数(它们实际上比你的原始函数慢一点,因为windows更通用)。这正是您的原始代码所做的——将 9 个 3x3 窗口创建为一个 5x5 数组,然后将它们重新整形为一个 9x9 数组进行处理。

def windows(a, w, _as_strided=stride_tricks.as_strided):
    windows_shape = tuple(sa - sw + 1 for sa, sw in zip(a.shape, w))
    windows_shape += w
    windows_strides = a.strides + a.strides
    return _as_strided(a, windows_shape, windows_strides)

def rotation_matrix1(section, _windows=windows):
    rot_filters = windows(section.reshape(5, 5), (3, 3)).reshape(9, 9)
    return rot_filters[rot_filters.std(1).argmin(),:].mean()

windows只要窗口具有相同的维数,就可以处理任何维数的数组。以下是其工作原理的细分:

    windows_shape = tuple(sa - sw + 1 for sa, sw in zip(a.shape, w))

我们可以将windows数组视为 nd 数组的 nd 数组。外部 nd 阵列的形状由窗口在较大阵列内的自由度决定;在每个维度上,窗口可以占据的位置数等于较大数组的长度减去窗口的长度加一。在这种情况下,我们有一个 3x3 窗口到一个 5x5 数组中,所以外部二维数组是一个 3x3 数组。

    windows_shape += w

内部 nd 数组的形状与窗口本身的形状相同。在我们的例子中,这又是一个 3x3 数组。

现在大步前进。我们必须为外部 nd 数组和内部 nd 数组定义步幅。但事实证明它们是一样的!毕竟,窗口在更大的数组中移动的方式与单个索引在数组中移动的方式相同,对吧?

    windows_strides = a.strides + a.strides

现在我们有了创建窗口所需的所有信息:

    return _as_strided(a, windows_shape, windows_strides)
于 2012-10-19T16:06:43.080 回答
1

对于复杂的每像素 + 邻域操作,您可以考虑使用 cython 来提高性能。它允许以接近 Python 语法的 for 循环高效地编写代码,该语法稍后会转换为 C 代码。

为了获得灵感,您可以查看 scikit-image 代码,例如:

https://github.com/scikit-image/scikit-image/blob/master/skimage/filter/_denoise.pyx

于 2012-10-19T08:54:58.047 回答