1

假设我有 3 个(或 100 个)具有 dim=2 和 shape=(x, y) 的 ndarray,它们相互堆叠。

对于另一个数组下方的数组中的每个索引,与上面的值相比,下面的值更小,如下所示:

A = 
[ 0 0 1 1
  0 0 0 1 
  0 0 0 0 
  0 0 0 0 ]
B = 
[ 2 2 2 2
  2 2 2 2
  1 2 2 2
  1 1 2 2 ]
C =
[ 3 4 4 3
  3 4 4 4
  2 3 4 4
  2 2 2 4 ]

给定一个数字(例如 1.5),我想找到

  for each (x, y) of the ndarrays:
    (1) the index of the stacked array, that has the biggest value below and
    (2) the index of the stacked array, that has the smalest value above the number
that is, the sourunding "bouding layer" of the number)

对于上面的示例数组,这将是: 低于阈值的层的索引

I_biggest_smaller_number = 
[ 0 0 0 0
  0 0 0 0
  1 0 0 0
  1 1 0 0 ]

高于阈值的层数

I_smallest_bigger_number = 
[ 1 1 1 1
  1 1 1 1
  2 1 1 1
  2 2 1 1]

以最有效的方式使用 numpy. 任何帮助将不胜感激:)

4

1 回答 1

1

看来您想使用 NumPy 的maxminwhere函数的组合。

使用numpy.where允许我们根据标准找到矩阵条目的索引。在这种情况下,我们可以询问矩阵条目子集的最大值/最小值(大于或小于给定数字)的值的索引是什么。

这很拗口,但希望这里包含的代码应该有所帮助。不过要小心:在您的示例中,B[3,2]和中的值C[3,2]是相同的。也许这是一个错字;但是,我在下面的代码中对此做了一些假设。

import numpy as np

A = np.array([[0,0,1,1],
              [0,0,0,1],
              [0,0,0,0],
              [0,0,0,0]])

B = np.array([[2,2,2,2],
              [2,2,2,2],
              [1,2,2,2],
              [1,1,2,2]])

C = np.array([[3,4,4,3],
              [3,4,4,4],
              [2,3,4,4],
              [2,2,2,4]])

# I assume the arrays are stacked like this                           
stacked_arrays = np.array([A,B,C])
# So the shape of stacked_arrays in this case is (3,4,4)

n = 1.5 # The example value you gave

I_biggest_smaller_number=np.ndarray((4,4),np.int)
I_smallest_bigger_number=np.ndarray((4,4),np.int)

for x in xrange(stacked_arrays.shape[1]):
    for y in xrange(stacked_arrays.shape[2]):

        # Take values we are interested in; i.e. all values for (x,y)
        temp = stacked_arrays[:,x,y]

        # Find index of maximum value below n
        I_biggest_smaller_number[x,y]=np.where(temp==np.max(temp[np.where(temp<n)]))[0][-1]
# The [-1] takes the highest index if there are duplicates

        # Find index of minimum value above n
        I_smallest_bigger_number[x,y]=np.where(temp==np.min(temp[np.where(temp>n)]))[0][0]
# The [0] takes the lowest index if there are duplicates

print I_biggest_smaller_number
print
print I_smallest_bigger_number
于 2013-05-20T14:04:04.973 回答