2

我对 NumPy 比较陌生。任何人都有让这段代码更紧凑/更高效的想法,尤其是嵌套循环?顺便说一句, dist 和 data 是三维 numpy 数组。

def interpolate_to_distance(self,distance):

    interpolated_data=np.ndarray(self.dist.shape[1:])
    for j in range(interpolated_data.shape[1]):
        for i in range(interpolated_data.shape[0]):
            interpolated_data[i,j]=np.interp(
                                  distance,self.dist[:,i,j],self.data[:,i,j])

    return(interpolated_data)

谢谢!

4

1 回答 1

3

好吧,我会用这个来打个招呼:

def interpolate_to_distance(self, distance):
    dshape = self.dist.shape
    dist = self.dist.T.reshape(-1, dshape[-1])
    data = self.data.T.reshape(-1, dshape[-1])
    intdata = np.array([np.interp(distance, di, da)
                        for di, da in zip(dist, data)])
    return intdata.reshape(dshape[0:2]).T

它至少删除了一个循环(以及那些嵌套索引),但它并没有比原来的快多少,根据%timeitIPython 快 20%。另一方面,有很多(最终可能是不必要的)转置和重塑。

作为记录,我将它包装在一个虚拟类中,并用随机数填充了一些 3 x 3 x 3 数组以进行测试:

import numpy as np

class TestClass(object):
    def interpolate_to_distance(self, distance):
        dshape = self.dist.shape
        dist = self.dist.T.reshape(-1, dshape[-1])
        data = self.data.T.reshape(-1, dshape[-1])
        intdata = np.array([np.interp(distance, di, da)
                            for di, da in zip(dist, data)])
        return intdata.reshape(dshape[0:2]).T

    def interpolate_to_distance_old(self, distance):
        interpolated_data=np.ndarray(self.dist.shape[1:])
        for j in range(interpolated_data.shape[1]):
            for i in range(interpolated_data.shape[0]):
                interpolated_data[i,j]=np.interp(
                           distance,self.dist[:,i,j],self.data[:,i,j])
        return(interpolated_data)

if __name__ == '__main__':
    testobj = TestClass()

    testobj.dist = np.random.randn(3, 3, 3)
    testobj.data = np.random.randn(3, 3, 3)

    distance = 0
    print 'Old:\n', testobj.interpolate_to_distance_old(distance)
    print 'New:\n', testobj.interpolate_to_distance(distance)

哪个打印(对于我的特定随机数):

Old:
[[-0.59557042 -0.42706077  0.94629049]
 [ 0.55509032 -0.67808257 -0.74214045]
 [ 1.03779189 -1.17605275  0.00317679]]
New:
[[-0.59557042 -0.42706077  0.94629049]
 [ 0.55509032 -0.67808257 -0.74214045]
 [ 1.03779189 -1.17605275  0.00317679]]

我也尝试过np.vectorize(np.interp),但无法让它发挥作用。我怀疑如果它确实有效,那会快得多。

我也无法开始np.fromfunction工作,因为它将 (2) 3 x 3(在这种情况下)索引数组传递给np.interp,与您从np.mgrid.

另一个注意事项:根据文档np.interp

np.interp不检查 x 坐标序列xp是否在增加。如果 xp不增加,结果是无稽之谈。一个简单的增长检查是:

np.all(np.diff(xp) > 0)

显然,我的随机数违反了“不断增加”的规则,但您必须更加小心。

于 2012-12-18T04:37:39.763 回答