7

如何根据实际索引值屏蔽数组?

也就是说,如果我有一个 10 x 10 x 30 矩阵,并且我想在第一个和第二个索引彼此相等时屏蔽数组。

例如,[1, 1 , :] 应该被屏蔽,因为 1 和 1 彼此相等,但[1, 2, :]不应该因为它们不相等。

我只是用第三维来问这个问题,因为它类似于我当前的问题并且可能会使事情复杂化。但我的主要问题是,如何根据索引的值屏蔽数组?

4

2 回答 2

7

通常,要访问索引的值,您可以使用np.meshgrid

i, j, k = np.meshgrid(*map(np.arange, m.shape), indexing='ij')
m.mask = (i == j)

这种方法的优点是它适用于 、 和 上的任意i布尔j函数k。它比使用identity特殊情况要慢一些。

In [56]: %%timeit
   ....: i, j, k = np.meshgrid(*map(np.arange, m.shape), indexing='ij')
   ....: i == j
10000 loops, best of 3: 96.8 µs per loop

正如@Jaime 指出的那样,meshgrid支持一个sparse选项,它不会做太多重复,但在某些情况下需要多加注意,因为它们不广播。它将节省内存并加快速度。例如,

In [77]: x = np.arange(5)

In [78]: np.meshgrid(x, x)
Out[78]: 
[array([[0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4],
       [0, 1, 2, 3, 4]]),
 array([[0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2],
       [3, 3, 3, 3, 3],
       [4, 4, 4, 4, 4]])]

In [79]: np.meshgrid(x, x, sparse=True)
Out[79]: 
[array([[0, 1, 2, 3, 4]]),
 array([[0],
       [1],
       [2],
       [3],
       [4]])]

所以,你可以使用sparse他所说的版本,但你必须强制广播:

i, j, k = np.meshgrid(*map(np.arange, m.shape), indexing='ij', sparse=True)
m.mask = np.repeat(i==j, k.size, axis=2)

和加速:

In [84]: %%timeit
   ....: i, j, k = np.meshgrid(*map(np.arange, m.shape), indexing='ij', sparse=True)
   ....: np.repeat(i==j, k.size, axis=2)
10000 loops, best of 3: 73.9 µs per loop
于 2013-09-17T22:28:42.840 回答
0

在您想要屏蔽对角线的特殊情况下,您可以使用np.identity()沿对角线返回一个的函数。由于您有第三维,我们必须将该第三维添加到单位矩阵:

m.mask = np.identity(10)[...,None]*np.ones((1,1,30))

构建该数组可能有更好的方法,但它基本上是堆叠 30 个np.identity(10)数组。例如,这是等价的:

np.dstack((np.identity(10),)*30)

但更慢:

In [30]: timeit np.identity(10)[...,None]*np.ones((1,1,30))
10000 loops, best of 3: 40.7 µs per loop

In [31]: timeit np.dstack((np.identity(10),)*30)
1000 loops, best of 3: 219 µs per loop

还有@Ophion 的建议

In [33]: timeit np.tile(np.identity(10)[...,None], 30)
10000 loops, best of 3: 63.2 µs per loop

In [71]: timeit np.repeat(np.identity(10)[...,None], 30)
10000 loops, best of 3: 45.3 µs per loop
于 2013-09-17T22:12:48.303 回答