我想在 Cython 的 prange 循环中并行执行 3 个采用相同输入的函数。他们在相同的变量TV
和上添加值,并du
采用相同的变量。代码的目的是计算四个主要方向上的像素梯度,然后按像素计算总变化。
为此,我创建了一个包含方法名称的列表并遍历该列表。我有这个代码:
cdef void TV_norm(float[:, :] ux, float[:, :] uy, float[:, :] output, float epsilon, float p) nogil:
cdef int M = ux.shape[0]
cdef int N = ux.shape[1]
cdef int i, j
cdef float inv_p = 1./p
cdef float eps = epsilon**p
with parallel(num_threads=64):
for i in prange(M, schedule="guided"):
for j in range(N):
output[i, j] += (abs(ux[i, j])** p + abs(uy[i, j])** p + eps) **inv_p
cdef void center_diff(float[:, :] u, float[:, :] TV, float[:, :] du, int di, int dj, float epsilon, float p):
ux = np.roll(u, (di, 0)) - u
uy = np.roll(u, (0, dj)) - u
TV_norm(ux, uy, TV, epsilon, p)
du -= ux + uy
cdef void i_diff(float[:, :] u, float[:, :] TV, float[:, :] du, int di, int dj, float epsilon, float p):
ux = u - np.roll(u, (-di, 0))
uy = np.roll(u, (-di, dj)) - np.roll(u, (-di, 0))
TV_norm(ux, uy, TV, epsilon, p)
du += ux
cdef void j_diff(float[:, :] u, float[:, :] TV, float[:, :] du, int di, int dj, float epsilon, float p):
ux = np.roll(u, (di, -dj)) - np.roll(u, (0, -dj))
uy = u - np.roll(u, (0, -dj))
TV_norm(ux, uy, TV, epsilon, p)
du += uy
cdef list divTV_dual(float[:, :] u, float epsilon=0, float p=1):
cdef np.ndarray[DTYPE_t, ndim=2] TV = np.zeros_like(u)
cdef np.ndarray[DTYPE_t, ndim=2] du = TV.copy()
cdef list shifts = [[1, 1],[-1, 1],[1,-1],[-1, -1]]
cdef list methods = [center_diff, i_diff, j_diff]
with nogil, parallel(num_threads=4):
for i in prange(4, schedule="static"):
with gil:
di = shifts[i][0]
dj = shifts[i][1]
for j in range(3):
methods[j](u, TV, du, di, dj, epsilon, p)
return [du, TV]
虽然它在纯 Python 中工作,但 Cython 在编译时失败:
/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py in run_cell_magic(self, magic_name, line, cell)
2129 magic_arg_s = self.var_expand(line, stack_depth)
2130 with self.builtin_trap:
-> 2131 result = fn(magic_arg_s, cell)
2132 return result
2133
<decorator-gen-127> in cython(self, line, cell)
/usr/local/lib/python3.5/dist-packages/IPython/core/magic.py in <lambda>(f, *a, **k)
185 # but it's overkill for just that one bit of state.
186 def magic_deco(arg):
--> 187 call = lambda f, *a, **k: f(*a, **k)
188
189 if callable(arg):
/usr/local/lib/python3.5/dist-packages/Cython/Build/IpythonMagic.py in cython(self, line, cell)
289 build_extension.build_temp = os.path.dirname(pyx_file)
290 build_extension.build_lib = lib_dir
--> 291 build_extension.run()
292 self._code_cache[key] = module_name
293
/usr/lib/python3.5/distutils/command/build_ext.py in run(self)
336
337 # Now actually compile and link everything.
--> 338 self.build_extensions()
339
340 def check_extensions_list(self, extensions):
/usr/lib/python3.5/distutils/command/build_ext.py in build_extensions(self)
445 self._build_extensions_parallel()
446 else:
--> 447 self._build_extensions_serial()
448
449 def _build_extensions_parallel(self):
/usr/lib/python3.5/distutils/command/build_ext.py in _build_extensions_serial(self)
470 for ext in self.extensions:
471 with self._filter_build_errors(ext):
--> 472 self.build_extension(ext)
473
474 @contextlib.contextmanager
/usr/lib/python3.5/distutils/command/build_ext.py in build_extension(self, ext)
530 debug=self.debug,
531 extra_postargs=extra_args,
--> 532 depends=ext.depends)
533
534 # XXX outdated variable, kept here in case third-part code
/usr/lib/python3.5/distutils/ccompiler.py in compile(self, sources, output_dir, macros, include_dirs, debug, extra_preargs, extra_postargs, depends)
572 except KeyError:
573 continue
--> 574 self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
575
576 # Return *all* object filenames, not just the ones we just built.
/usr/lib/python3.5/distutils/unixccompiler.py in _compile(self, obj, src, ext, cc_args, extra_postargs, pp_opts)
118 extra_postargs)
119 except DistutilsExecError as msg:
--> 120 raise CompileError(msg)
121
122 def create_static_lib(self, objects, output_libname,
CompileError: command 'x86_64-linux-gnu-gcc' failed with exit status 1
有什么意思吗?
编辑:
这个概念证明有效:
%%cython --compile-args=-O3 --compile-args=-ffast-math --compile-args=-fopenmp --link-args=-fopenmp
# cython: boundscheck=False
# cython: cdivision=True
# cython: wraparound=False
# cython: profile=True
cimport cython
from cython.parallel cimport parallel, prange
cdef foo(a):
print(a)
cdef bar(a):
print(a)
methods = [foo, bar]
cdef int i
with nogil, parallel():
for i in prange(2):
with gil:
methods[i]("a")