Given an array containing two 9x9 images with 2 channels each, built like this:
import numpy as np
img1 = np.arange(162).reshape(9, 9, 2).copy()
img2 = img1 * 2
batch = np.array([img1, img2])
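I need to split each image into 3x3x2 regions (stride=3), then locate and replace the maximum element of each region, per image and per channel. For the example above, these elements are at: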
(:, 2, 2, :)
(:, 2, 5, :)
(:, 2, 8, :)
(:, 5, 2, :)
(:, 5, 5, :)
(:, 5, 8, :)
(:, 8, 2, :)
(:, 8, 5, :)
(:, 8, 8, :)
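The positions listed above can be double-checked with a plain loop over the regions. This is only a verification sketch, assuming the same batch as in the example; the name max_positions is introduced here for illustration.

```python
import numpy as np

img1 = np.arange(162).reshape(9, 9, 2)
img2 = img1 * 2
batch = np.array([img1, img2])

region_size = 3
max_positions = set()
for top in range(0, batch.shape[1], region_size):
    for left in range(0, batch.shape[2], region_size):
        region = batch[:, top:top + region_size, left:left + region_size, :]
        # Flatten the 3x3 spatial window; argmax then runs per image and per channel.
        flat = region.reshape(batch.shape[0], -1, batch.shape[3])
        dr, dc = np.unravel_index(flat.argmax(axis=1), (region_size, region_size))
        # Convert in-region offsets back to (row, column) positions in the full image.
        max_positions.update(zip((top + dr).ravel().tolist(),
                                 (left + dc).ravel().tolist()))

print(sorted(max_positions))
```

For this input it should print the nine (row, column) pairs listed above, since the values grow with both row and column inside every region.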
My solution so far looks like this:
batch_size, _, _, channels = batch.shape
region_size = 3
# For the (0, 0) region
region_slice = (slice(batch_size), slice(region_size), slice(region_size), slice(channels))
region = batch[region_slice]
new_values = np.arange(batch_size * channels)
# Flatten each channel of an image
region_3d = region.reshape(batch_size, region_size ** 2, channels)
# Find indices of max element for each channel
region_3d_argmax = region_3d.argmax(axis=1)
# Manually unravel indices
region_argmax = (
np.repeat(np.arange(batch_size), channels),
*np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
np.tile(np.arange(channels), batch_size)
)
batch[region_slice][region_argmax] = new_values
This code has two problems:
- Reshaping region may return a copy instead of a view
- The indices are unraveled manually
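The copy-versus-view concern can be checked directly. A minimal sketch, assuming the batch from the example above, that uses np.shares_memory to see whether each intermediate array still aliases batch:

```python
import numpy as np

img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])

# Basic slicing returns a view into batch.
region = batch[:, :3, :3, :]
# Merging the two strided spatial axes is not possible without copying.
region_3d = region.reshape(2, 9, 2)

print(np.shares_memory(region, batch))     # True
print(np.shares_memory(region_3d, batch))  # False
```

So for this layout the reshape does produce a copy. In the code above that copy is only read by argmax, while the assignment goes through batch[region_slice], which is a basic-slicing view, so the write still reaches batch.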
What would be a better way to do this?
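One possible direction, offered as a sketch rather than a definitive answer: compute the argmax for all regions at once on a rearranged (read-only) copy, then write back into batch with broadcast index arrays. The helper name replace_region_max is hypothetical, and the sketch assumes the height and width are exact multiples of region_size and that new_values broadcasts to shape (batch, regions_h, regions_w, channels).

```python
import numpy as np

def replace_region_max(batch, new_values, region_size=3):
    """Replace the per-channel maximum of every region in place (hypothetical helper)."""
    n, h, w, c = batch.shape
    nbh, nbw = h // region_size, w // region_size
    # Group pixels by region: (n, region row, region col, offset within region, c).
    # The transpose + reshape may copy, but this array is only read.
    blocks = (batch.reshape(n, nbh, region_size, nbw, region_size, c)
                   .transpose(0, 1, 3, 2, 4, 5)
                   .reshape(n, nbh, nbw, region_size ** 2, c))
    arg = blocks.argmax(axis=3)  # shape (n, nbh, nbw, c)
    dr, dc = np.unravel_index(arg, (region_size, region_size))
    # Open grids broadcast against dr/dc to address every maximum in batch itself.
    b, i, j, ch = np.ogrid[:n, :nbh, :nbw, :c]
    batch[b, i * region_size + dr, j * region_size + dc, ch] = new_values
    return batch

img1 = np.arange(162).reshape(9, 9, 2)
batch = np.array([img1, img1 * 2])
# Same replacement values as in the question, repeated for every region.
new_values = np.arange(batch.shape[0] * batch.shape[-1]).reshape(2, 1, 1, 2)
replace_region_max(batch, new_values)
print(batch[:, 2, 2, :])
```

Because the assignment indexes batch directly, it does not matter whether the rearranged blocks array is a view or a copy, and the batch and channel indices come from broadcasting instead of np.repeat and np.tile.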