Given a NumPy boolean array:
arr = np.array([1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1])
I want to indicate where there are at least n consecutive True values (from left to right).
For n = 2:
#            True 2x (or more) in a row
#            /  \        /     \
arr = [1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1]
# becomes:
res = [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0]
#               ^-----------^--^-------- A pattern of 2 or more consecutive True's ends at each of these locations
For n = 3:
#                        True 3x (or more) in a row
#                        /     \
arr = [1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1]
# becomes:
res = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
#                              ^-------- A pattern of 3 or more consecutive True's ends at this location
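To pin down the expected output precisely, here is a plain loop version of the rule illustrated above. It is only a reference sketch of the behaviour I am after (and exactly the kind of per-element loop I would like to avoid); the name `run_ends_loop` is made up for illustration.

```python
import numpy as np

def run_ends_loop(arr, n):
    """Mark every index at which a run of at least n consecutive truthy values ends."""
    res = np.zeros(len(arr), dtype=int)
    streak = 0  # length of the current run of truthy values
    for i, value in enumerate(arr):
        streak = streak + 1 if value else 0
        if streak >= n:
            res[i] = 1
    return res

arr = np.array([1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1])
res_n2 = run_ends_loop(arr, 2)
res_n3 = run_ends_loop(arr, 3)
```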
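Is there a Pythonic way to do this without a for loop over every element? Performance matters here, because my application will perform this operation 1000 times on boolean arrays containing 1000 elements.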
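Notes worth mentioning:
- n can be any value of 2 or greater.
- The run of n consecutive True values can occur anywhere in the array: at the beginning, in the middle, or at the end.
- The shape of the result array must match the shape of the original array.
Any help would be greatly appreciated.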
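One loop-free direction that seems to satisfy the notes above is a cumulative sum: a window of length n ending at index i is all True exactly when the count of True values in that window equals n. The sketch below is an assumption about how this could be done, not a benchmarked solution; `run_ends_cumsum` is a hypothetical name, and n is assumed to be at least 1.

```python
import numpy as np

def run_ends_cumsum(arr, n):
    """Vectorized sketch: 1 where the last n values (ending here) are all True."""
    a = np.asarray(arr, dtype=bool)
    # Prefix counts with a leading 0, so c[j] is the number of True values in a[:j].
    c = np.concatenate(([0], np.cumsum(a)))
    res = np.zeros(a.shape, dtype=int)
    if n <= a.size:
        # The window a[i-n+1 : i+1] contains c[i+1] - c[i+1-n] True values.
        res[n - 1:] = (c[n:] - c[:-n]) == n
    return res

arr = np.array([1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1])
out_n2 = run_ends_cumsum(arr, 2)
out_n3 = run_ends_cumsum(arr, 3)
```

The first n - 1 positions are left at 0 because no full window can end there, which keeps the result the same shape as the input.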
Benchmark of the answers
fully vectorized by alain-t:
10000 loops, best of 5: 0.138 seconds per loop, worst of 5: 0.149 seconds per loop
pad/shift by mozway:
10000 loops, best of 5: 1.62 seconds per loop, worst of 5: 1.71 seconds per loop
sliding_window_view by kubatucka (with padding by mozway):
10000 loops, best of 5: 1.15 seconds per loop, worst of 5: 1.54 seconds per loop
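For readers who want to reproduce timings of this kind, here is a minimal sketch of a sliding-window variant with left padding, plus a small `timeit` harness. It is my own reading of the general idea and is not necessarily the code the answerers posted or the code measured above; `run_ends_window` and `best_and_worst` are hypothetical names, and `sliding_window_view` needs NumPy 1.20 or newer.

```python
import timeit

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def run_ends_window(arr, n):
    """Sketch: pad n - 1 False values on the left, then test each length-n window."""
    a = np.asarray(arr, dtype=bool)
    padded = np.concatenate((np.zeros(n - 1, dtype=bool), a))
    # Window k covers a[k-n+1 : k+1]; the padding keeps the output as long as the input.
    return sliding_window_view(padded, n).all(axis=1).astype(int)

def best_and_worst(func, *args, number=1000, repeat=5):
    """Return (best, worst) total seconds over `repeat` rounds of `number` calls."""
    times = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    return min(times), max(times)

arr = np.array([1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1])
win_n2 = run_ends_window(arr, 2)
win_n3 = run_ends_window(arr, 3)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    big = rng.random(1000) < 0.5  # 1000-element boolean array, as in the question
    best, worst = best_and_worst(run_ends_window, big, 3)
    print(f"1000 loops, best of 5: {best:.3g} s, worst of 5: {worst:.3g} s")
```

Timings depend heavily on the machine and NumPy version, so any numbers printed by this harness are not comparable with the figures listed above.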