1

我有两个数组:

a = np.array([[1, 2], [3, 4], [5, 6]])

b = np.array([[1, 1, 1, 3, 3],
              [1, 2, 4, 5, 9],
              [1, 2, 3, 4, 5]])

预期的输出将与数组“a”的形状相匹配,并且是:

array([True, False], [False, True], [True, False])

数组 a 和 b 的第一个维度大小始终匹配(在本例中为 3)。

我希望计算的是每个数组的每个索引(0 到 2,因为这里有 3 个维度)是数组“a”中的每个数字是否存在于数组“b”的相应第二维中。

我可以使用以下代码循环解决这个问题,但我想对其进行矢量化以获得速度提升,但在这里坐了几个小时,我无法弄清楚:

output = np.full(a.shape, False)
assert len(a) == len(b)
for i in range(len(a)):
    output[i] = np.isin(a[i], b[i])

感谢您的任何指导!任何事情都会非常感激:)

4

1 回答 1

2

适当地重塑数组,以便它们在比较时可以正确广播:

(a[...,None] == b[:,None]).any(2)

#[[ True False]
# [False  True]
# [ True False]]
  • a[...,None]在末尾添加一个额外的维度,带有 shape (3, 2, 1)
  • b[:,None]插入一个尺寸作为第二轴,形状为(3, 1, 5);
  • 当您比较两个数组时,两者都将被广播,(3, 2, 5)因此基本上您将行中的a每个元素与相应行中的每个元素进行比较b
  • 最后,您可以检查 ; 中的每个元素是否有任何匹配项a
于 2021-07-24T17:56:30.187 回答