4

我正在自学 Python mpi4py 模块,用于在多个进程中进行编程。我编写了以下代码来练习分散。

from mpi4py import MPI

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

if rank == 0:
   data = [i for i in range(8)]
else:
   data = None
data = comm.scatter(data, root=0)
print str(rank) + ': ' + str(data)

用 8 个进程运行上述代码效果很好。但是,当我使用 4 个进程运行它时,出现错误:

Traceback (most recent call last):
  File "scatter.py", line 11, in <module>
    data = comm.scatter(data, root=0)
  File "Comm.pyx", line 874, in mpi4py.MPI.Comm.scatter (src/mpi4py.MPI.c:68023)
  File "pickled.pxi", line 656, in mpi4py.MPI.PyMPI_scatter (src/mpi4py.MPI.c:32402)
  File "pickled.pxi", line 127, in mpi4py.MPI._p_Pickle.dumpv (src/mpi4py.MPI.c:26813)
ValueError: expecting 4 items, got 8

这个错误是什么意思?我的意图是将我的 8 个项目的大数组分解为 8 / 4 = 2 个项目的小数组,并向每个进程发送一个这样的子数组。我怎么做?如果可能的话,我还想概括为不均匀分成 8 个的进程数,例如 3 个。

4

1 回答 1

11

似乎comm.scatter不能count作为参数,并期望将精确comm.size元素列表作为数据分散;因此您需要自己在进程之间分配数据。这样的事情会做:

if rank == 0:
    data = [i for i in range(8)]
# dividing data into chunks
    chunks = [[] for _ in range(size)]
    for i, chunk in enumerate(data):
        chunks[i % size].append(chunk)
else:
    data = None
    chunks = None
data = comm.scatter(chunks, root=0)
print str(rank) + ': ' + str(data)

[physics@tornado] ~/utils> mpirun -np 3 ./mpi.py 
2: [2, 5]
0: [0, 3, 6]
1: [1, 4, 7]

希望这可以帮助。

于 2012-10-10T09:13:51.387 回答