我知道 maxpool 并且我在 pytorch 中使用它。带扩张参数的 Maxpool 如下:
现在我想要一种特殊形式的 maxpool,在没有中心元素的情况下做 maxpool。也就是说内核大小是 3X3,但应该删除中心元素。因此,结果应该来自其余 8 个元素。
现在我正在使用 for 循环,如何使用 numpy 或 pytorch 或其他任何东西来加速它?
import numpy as np
from timeit import default_timer as timer
def MaxPool_special(kh, kw, arr):
"""
to do maxpool without central element
:param kh: should always be 3
:param kw: should always be 3
:param arr: the input array
:return: arr_res: output array
"""
h, w = arr.shape[:2]
arr_res = np.array([[maxpool_ij(i, j, arr, kh, kw) for j in range(w)] for i in range(h)])
return arr_res
def maxpool_ij(i, j, arr, dh, dw):
"""
find the maximum value around point(i,j) with dilated parameter
"""
Mmax = None
imin, imax = i - dh, i + dh
jmin, jmax = j - dw, j + dw
if imin >= 0 and imax < h and jmin >= 0 and jmax < w:
Mmax = np.max(
arr[[imin, imin, imin, i, i, imax, imax, imax], [jmin, j, jmax, jmin, jmax, jmin, j, jmax]])
elif imin < 0 and jmin < 0:
Mmax = np.max(arr[[i, imax, imax], [jmax, j, jmax]])
elif imin < 0 and jmax >= w:
Mmax = np.max(arr[[i, imax, imax], [jmin, jmin, j]])
elif imax >= h and jmin < 0:
Mmax = np.max(arr[[imin, imin, i], [j, jmax, jmax]])
elif imax >= h and jmax >= w:
Mmax = np.max(arr[[imin, imin, i], [jmin, j, jmin]])
elif imin < 0:
Mmax = np.max(arr[[i, i, imax, imax, imax], [jmin, jmax, jmin, j, jmax]])
elif imax >= h:
Mmax = np.max(arr[[imin, imin, imin, i, i], [jmin, j, jmax, jmin, jmax]])
elif jmin < 0:
Mmax = np.max(arr[[imin, imin, i, imax, imax], [j, jmax, jmax, j, jmax]])
elif jmax >= w:
Mmax = np.max(arr[[imin, imin, i, imax, imax], [jmin, j, jmin, jmin, j]])
assert Mmax, f'Wrong logic above!{imin, imax, jmin, jmax, h, w}'
return Mmax
# generate input array
h, w = 400, 500
arr = np.random.randint(0, 256, h * w).reshape(h, w)
tic = timer()
grayPool = MaxPool_special(3, 3, arr)
toc = timer()
print(f'time cost for for-loops: {toc - tic}')
请帮我加速这段代码,谢谢!