
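I can run the vit_jax.ipynb notebook on Colab, train with it, and run my experiments. But when I try to replicate this on my cluster, I get the error below during training. The forward pass that computes accuracy, however, runs fine on my cluster.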

My cluster has 4 GTX 1080 GPUs with CUDA 10.1, and I am using tensorflow==2.4.0 and jax[cuda101]==0.2.18. I am running it as a Jupyter notebook inside a Docker container.
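One way to narrow this down is to check whether a bare cross-device all-reduce works on the cluster, independent of the ViT model. The training step presumably averages gradients across the GPUs with a collective such as jax.lax.pmean, which XLA executes through NCCL on GPU, while a per-device forward pass for accuracy may need no cross-device communication at all; that would be consistent with only training failing. The sketch below is a minimal test under that assumption. NCCL_DEBUG=INFO is a standard NCCL environment variable that makes NCCL log more detail about the call that failed; the axis name "devices" is arbitrary.

```python
import os

# Ask NCCL for verbose logs; must be set before the first NCCL call.
os.environ.setdefault("NCCL_DEBUG", "INFO")

import jax
import jax.numpy as jnp

n = jax.local_device_count()

# pmap with a collective: lax.psum sums the value across all mapped devices.
# On multiple GPUs, XLA lowers this all-reduce to NCCL, like a gradient pmean.
allreduce = jax.pmap(lambda x: jax.lax.psum(x, axis_name="devices"),
                     axis_name="devices")

# Device i contributes the scalar i; every device should end up with the total.
out = allreduce(jnp.arange(n, dtype=jnp.float32))
print(n, out)
```

If this small program fails with the same ncclGroupEnd() error, the problem lies in the NCCL/GPU/container setup rather than in the ViT training code.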

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-57-176d6124ae02> in <module>()
     11   opt_repl, loss_repl, update_rng_repl = update_fn_repl(
---> 12       opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
     13   losses.append(loss_repl[0])

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    182     try:
--> 183       return fun(*args, **kwargs)
    184     except Exception as e:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in f_pmapped(*args, **kwargs)
   1638         name=flat_fun.__name__, donated_invars=tuple(donated_invars),
-> 1639         global_arg_shapes=tuple(global_arg_shapes_flat))
   1640     return tree_unflatten(out_tree(), out)

/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
   1620     assert len(params['in_axes']) == len(args)
-> 1621     return call_bind(self, fun, *args, **params)
   1622 

/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1551   tracers = map(top_trace.full_raise, args)
-> 1552   outs = primitive.process(top_trace, fun, tracers, params)
   1553   return map(full_lower, apply_todos(env_trace_todo(), outs))

/usr/local/lib/python3.7/dist-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1623   def process(self, trace, fun, tracers, params):
-> 1624     return trace.process_map(self, fun, tracers, params)
   1625 

/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    606   def process_call(self, primitive, f, tracers, params):
--> 607     return primitive.impl(f, *tracers, **params)
    608   process_map = process_call

/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in xla_pmap_impl(fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *args)
    636                           ("fingerprint", fingerprint))
--> 637   return compiled_fun(*args)
    638 

/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in execute_replicated(compiled, backend, in_handler, out_handler, *args)
   1159   input_bufs = in_handler(args)
-> 1160   out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
   1161   if xla.needs_check_special():

UnfilteredStackTrace: RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:203: NCCL operation ncclGroupEnd() failed: unhandled system error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-57-176d6124ae02> in <module>()
     10 
     11   opt_repl, loss_repl, update_rng_repl = update_fn_repl(
---> 12       opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
     13   losses.append(loss_repl[0])
     14   lrs.append(lr_fn(step))

/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in execute_replicated(compiled, backend, in_handler, out_handler, *args)
   1158 def execute_replicated(compiled, backend, in_handler, out_handler, *args):
   1159   input_bufs = in_handler(args)
-> 1160   out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
   1161   if xla.needs_check_special():
   1162     for bufs in out_bufs:

RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:203: NCCL operation ncclGroupEnd() failed: unhandled system error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).


Has anyone run into this problem before, or is there any way to fix it?


1 Answer


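It's hard to say for sure without more information, but this error may be caused by running out of GPU memory. Depending on your local setup, you may be able to fix it by increasing the fraction of GPU memory that XLA reserves, for example by setting the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable to 0.9 or a similarly high value.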
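A minimal sketch of setting this from inside the notebook. The variable is only read when JAX initializes its GPU backend, so it has to be set before jax is imported; in Jupyter that means restarting the kernel and running this cell first. Outside the notebook, the same variable could instead be passed to the container with docker run -e.

```python
import os

# Let XLA preallocate up to 90% of each GPU's memory.
# Must happen before JAX initializes the GPU backend (i.e. before `import jax`).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

import jax  # imported only after the environment variable is in place

print(jax.devices())
```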

Alternatively, you could try running the code on a smaller problem that fits in your local hardware's memory.
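With pmap, each GPU holds global_batch / n_devices examples, so lowering the global batch size is one direct way to shrink the problem per GPU. The helper below is hypothetical (shard_batch is not part of vit_jax) and shows how a global batch is split along a new leading device axis before being passed to a pmapped function; the batch of 32 images at 224x224x3 and the device count of 4 (matching the 4 GTX 1080s) are illustrative assumptions. In practice the device count would come from jax.local_device_count().

```python
import numpy as np


def shard_batch(batch, n_devices):
    """Reshape a global batch [B, ...] into [n_devices, B // n_devices, ...] for pmap."""
    b = batch.shape[0]
    if b % n_devices:
        raise ValueError(f"global batch size {b} is not divisible by {n_devices} devices")
    return batch.reshape((n_devices, b // n_devices) + batch.shape[1:])


# Hypothetical smaller problem: 32 images, i.e. 8 images per GPU on 4 GPUs.
images = np.zeros((32, 224, 224, 3), dtype=np.float32)
sharded = shard_batch(images, n_devices=4)
print(sharded.shape)  # (4, 8, 224, 224, 3)
```

Halving the global batch halves the activations each GPU must hold during the backward pass, at the cost of noisier gradient estimates.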

answered 2021-08-03T17:32:47.007