4

我有一个像这样定义的 3D 复杂 numpy 数组:

> import numpy as np
> a = np.random.rand(2,3,4) + np.random.rand(2,3,4) * 1j
> a
array([[[ 0.40506245+0.68587874j,  0.74700976+0.73208816j,
      0.42010818+0.31124884j,  0.27181199+0.54599156j],
    [ 0.29457621+0.34057513j,  0.82490182+0.63943948j,
      0.46887722+0.12734375j,  0.77184637+0.21522095j],
    [ 0.67774944+0.8208908j ,  0.41476702+0.85332392j,
      0.10084665+0.56146324j,  0.71325041+0.77306548j]],

   [[ 0.77843387+0.23660274j,  0.23671262+0.63997834j,
      0.60831419+0.41741288j,  0.53870756+0.13747055j],
    [ 0.12477767+0.54603678j,  0.60537090+0.89208227j,
      0.16027151+0.17575777j,  0.18801875+0.27282324j],
    [ 0.82308271+0.97238411j,  0.47458327+0.75200695j,
      0.16085009+0.60620705j,  0.79766571+0.76470634j]]])

我需要将它打印 s成特定格式的字符串,这有点像 MATLAB,我发现的最佳方法如下:(对我来说,描述格式的最佳方法是使用此代码)

> s = ''
> for k in range(a.shape[2]):
>   for j in range(a.shape[1]):
>     for i in range(a.shape[0]):
>       s += str(a[i,j,k].real) + ' '
>   for j in range(a.shape[1]):
>     for i in range(a.shape[0]):
>       s += str(a[i,j,k].imag) + ' '

我对这段看起来不太“pythonic”的代码不满意(我来自 C++,对 Python 不太了解)。我确信 Python 提供了一些可以在这里使用的好语法(例如列表推导),但我对它不是很熟悉。

因此,我的问题如下:如何改进此代码以使其更加 Pythonic?

编辑:这个 3D 数组被视为 2×3 复杂矩阵的数组。该格式包括打印第一个矩阵的实部,然后是虚部,然后以这种方式遍历每个矩阵。

这是您在 MATLAB 中运行此代码时获得的格式:

> a = rand(2,3,4) + rand(2,3,4) * 1i;
> s = sprintf('%g %g ', [real(a) imag(a)]);

我的主要目标是与这种格式兼容。

4

2 回答 2

3

字符串连接通常使用 join 完成:

s += str(a[i,j,k].imag) + ' '

可以替换为

s += ' '.join(str(a[i,j,k].imag))

在全球范围内应用,1-liner 可以是:

s = ' '.join(' '.join(str(a[i,j,k].real) for j in range(a.shape[1]) for i in range(a.shape[0])) + ' ' + ' '.join(str(a[i,j,k].imag) for j in range(a.shape[1]) for i in range(a.shape[0])) for k in range(a.shape[2]))

不是很清楚。我会保留 for ... k 循环并这样写:

s = ''
for k in range(a.shape[2]):
    s += ' '.join(str(a[i,j,k].real) for j in range(a.shape[1]) for i in range(a.shape[0]))
    s += ' '
    s += ' '.join(str(a[i,j,k].imag) for j in range(a.shape[1]) for i in range(a.shape[0]))
    s += ' '

编辑

这很重,numpy 有很多工具。这是一个更简单的版本。第一行重新格式化矩阵以简化第二行的工作:

b = [numpy.vstack((a.real.T[i], a.imag.T[i])) for i in range(a.shape[2])]
s = ' '.join(str(d) for x in b for d in x.flat)

编辑 2

还是可以简化的

' '.join([str(x) for x in np.hstack((a.T.real, a.T.imag)).flat])
于 2013-06-03T14:32:47.737 回答
1

有了足够的思考,您应该能够避免创建中间副本。但既然人生苦短,那又如何:

' '.join(np.hstack([a.T.real, a.T.imag]).astype(str).flat)

例如:

>>> a
array([[[ 0.75878533+0.6450401j ,  0.97544304+0.95294337j,
          0.72619451+0.70150035j,  0.53653874+0.72336166j],
        [ 0.44497093+0.59486404j,  0.48346416+0.602289j  ,
          0.89508307+0.10804834j,  0.60925276+0.78463914j],
        [ 0.75324059+0.35750314j,  0.77764455+0.52714092j,
          0.60422248+0.45825998j,  0.06100151+0.98814297j]],

       [[ 0.25167445+0.26036597j,  0.14479218+0.63888545j,
          0.69195476+0.65571239j,  0.75384667+0.35208925j],
        [ 0.33299320+0.95810933j,  0.28706287+0.92696162j,
          0.80174074+0.73461441j,  0.64070651+0.95546677j],
        [ 0.32726129+0.28131131j,  0.84847281+0.0043481j ,
          0.20002495+0.92129643j,  0.85657582+0.17598515j]]])
>>> new = ' '.join(np.hstack([a.T.real, a.T.imag]).astype(str).flat)
>>> new
'0.758785326622 0.251674447258 0.444970928938 0.332993197954 0.753240586102 0.3272612899 0.645040097487 0.260365974319 0.59486403781 0.958109327206 0.357503144442 0.281311309104 0.975443036171 0.14479217684 0.483464161328 0.287062874161 0.777644547623 0.84847280757 0.952943365086 0.638885451204 0.602289004931 0.926961617163 0.527140924938 0.00434810439813 0.726194510838 0.691954756116 0.895083070782 0.801740737909 0.604222482831 0.200024953365 0.701500350108 0.655712387542 0.108048340908 0.734614410363 0.458259975834 0.921296429741 0.536538738872 0.75384667023 0.609252761053 0.640706514463 0.0610015096191 0.856575822125 0.723361662643 0.35208924756 0.784639135069 0.955466768932 0.988142972679 0.175985147504'
>>> original(a).strip() == new
True

更新:如果.astype(str)由于某种原因无法正常工作,则作为后备:

>>> new2 = ' '.join(map(str, np.hstack([a.T.real, a.T.imag]).flat))
>>> original(a).strip() == new2
True
于 2013-06-03T15:21:37.730 回答