10

我试图了解如何使用 nditer 进行缩减,在我的情况下将 3d 数组转换为 2d 数组。

我按照这里的帮助 http://docs.scipy.org/doc/numpy/reference/arrays.nditer.html并设法创建了一个在输入的最后一个轴上应用归约的函数。有了这个功能

def nditer_sum(data, red_axes):
    it = numpy.nditer([data, None],
            flags=['reduce_ok', 'external_loop'],
            op_flags=[['readonly'], ['readwrite', 'allocate']],
            op_axes=[None, red_axes])
    it.operands[1][...] = 0

    for x, y in it:
        y[...] = x.sum()

    return it.operands[1]

我可以得到相当于 data.sum(axis=2) 的东西

>>> data = numpy.arange(2*3*4).reshape((2,3,4))
>>> nditer_sum(data, [0, 1, -1])
[[ 6 22 38]
[54 70 86]]
>>> data.sum(axis=2)
[[ 6 22 38]
[54 70 86]]

因此,要获得与 data.sum(axis=0) 等效的内容,尽管将参数 red_axes 更改为 [-1, 0,1] 就足够了,但结果却大不相同。

>>> data = numpy.arange(2*3*4).reshape((2,3,4))
>>> data.sum(axis=0)
[[12 14 16 18]
 [20 22 24 26]
 [28 30 32 34]]
>>> nditer_sum(data, [-1, 0, 1])
[[210 210 210 210]
 [210 210 210 210] 
 [210 210 210 210]]

在 nditer_sum (for x,y in it:) 内部的 for 循环中,迭代器循环 2 次,每次给出一个长度为 12 的数组,而不是循环 12 次,每次给出一个长度为 2 的数组。我已经多次阅读了 numpy 文档,并在谷歌上对此无济于事。我正在使用 numpy 1.6 和 python 2.7

4

2 回答 2

1

axis=0如果nditer订单更改为 ,则案例正常工作F。现在有 12 个步骤,其中包含您想要的大小为 (2,) 的数组。

it = np.nditer([data, None],
        flags=['reduce_ok', 'external_loop'],
        op_flags=[['readonly'], ['readwrite', 'allocate']],
        order='F',     # ADDED to loop starting with the last dimension
        op_axes=[None, red_axes])

但是对于中间axis=1情况,没有这样的解决方案。


对选定维度进行迭代的另一种方法是在降维数组上构造一个“multi_index”迭代器。我在https://stackoverflow.com/a/25097271/901925中发现 np.ndindex使用这个技巧来执行“浅迭代”。

对于这种axis=0情况,此功能有效:

def sum_1st(data):
    y = np.zeros(data.shape[1:], data.dtype)
    it = np.nditer(y, flags=['multi_index'])
    while not it.finished:
        xindex = tuple([slice(None)]+list(it.multi_index))
        y[it.multi_index] = data[xindex].sum()
        it.iternext()
    return y

或推广到任何轴:

def sum_any(data, axis=0):
    yshape = list(data.shape)
    del yshape[axis]
    y = np.zeros(yshape, data.dtype)
    it = np.nditer(y, flags=['multi_index'])
    while not it.finished:
        xindex = list(it.multi_index)
        xindex.insert(axis, slice(None))
        y[it.multi_index] = data[xindex].sum()
        it.iternext()
    return y
于 2014-11-04T06:49:18.347 回答
0

更改线路y[...] = x.sum()y[...] += x修复它(如此处的示例所示

于 2012-11-12T19:43:09.847 回答