我有一个图像,存储在一个uint8形状为 s的 numpy 数组中(planes, rows, cols)。我需要将它与存储在掩码中的值进行比较,掩码也是uint8s,形状为(mask_rows, mask_cols)。虽然图像可能非常大,但遮罩通常很小,通常(256, 256)要平铺image。为了简化代码,让我们假设rows = 100 * mask_rows和cols = 100 * mask_cols.
我目前处理这个阈值的方式是这样的:
out = image >= np.tile(mask, (image.shape[0], 100, 100))
在被 a扇耳光之前,我可以通过这种方式处理的最大数组MemoryError比(3, 11100, 11100). 我想它的方式,以这种方式做事,我有多达三个巨大的数组共存于内存中:image, tiledmask和 my return out。但是平铺的蒙版是相同的小数组,被复制了 10,000 多次。因此,如果我可以节省内存,我将只使用 2/3 的内存,并且应该能够处理大 3/2 的图像,因此大小约为(3, 13600, 13600). 顺便说一句,如果我使用
np.greater_equal(image, (image.shape[0], 100, 100), out=image)
我(失败)尝试利用周期性特性mask来处理更大的数组是mask用周期性线性数组索引:
mask = mask[None, ...]
rows = np.tile(np.arange(mask.shape[1], (100,))).reshape(1, -1, 1)
cols = np.tile(np.arange(mask.shape[2], (100,))).reshape(1, 1, -1)
out = image >= mask[:, rows, cols]
对于小型阵列,它确实产生与另一个阵列相同的结果,尽管速度降低了 20 倍(!!!),但对于较大的尺寸,它非常无法执行。它最终不会MemoryError导致 python 崩溃,即使对于其他方法处理没有问题的值也是如此。
我认为正在发生的事情是 numpy 实际上是在构造(planes, rows, cols)要索引的数组mask,所以不仅没有节省内存,而且由于它是一个int32s 数组,它实际上需要四倍的空间来存储......
关于如何解决这个问题的任何想法?为了省去你的麻烦,在下面找到一些沙盒代码来玩:
import numpy as np
def halftone_1(image, mask) :
    return np.greater_equal(image, np.tile(mask, (image.shape[0], 100, 100)))
def halftone_2(image, mask) :
    mask = mask[None, ...]
    rows = np.tile(np.arange(mask.shape[1]),
                   (100,)).reshape(1, -1, 1)
    cols = np.tile(np.arange(mask.shape[2]),
                   (100,)).reshape(1, 1, -1)
    return np.greater_equal(image, mask[:, rows, cols])
rows, cols, planes = 6000, 6000, 3
image = np.random.randint(-2**31, 2**31 - 1, size=(planes * rows * cols // 4))
image = image.view(dtype='uint8').reshape(planes, rows, cols)
mask = np.random.randint(256,
                         size=(1, rows // 100, cols // 100)).astype('uint8')
#np.all(halftone_1(image, mask) == halftone_2(image, mask))
#halftone_1(image, mask)
#halftone_2(image, mask)
import timeit
print timeit.timeit('halftone_1(image, mask)',
                    'from __main__ import halftone_1, image, mask',
                    number=1)
print timeit.timeit('halftone_2(image, mask)',
                    'from __main__ import halftone_2, image, mask',
                    number=1)