I can run vit_jax.ipynb on Colab, train, and run my experiments, but when I try to reproduce this on my cluster I hit the error below during training. The forward pass that computes accuracy runs fine on the cluster, though.
The cluster has 4 GTX 1080s with CUDA 10.1, and I am using tensorflow==2.4.0 and jax[cuda101]==0.2.18, running as a Jupyter notebook inside a docker container.
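For what it's worth, as far as I can tell the training step is the first place the notebook communicates across devices (gradients are averaged over the 4 GPUs), while the pmap'd accuracy forward pass runs independently on each device, so only training exercises NCCL. Here is a minimal sketch of my own (not taken from the notebook) that hits the same NCCL collective path:

import jax
import jax.numpy as jnp

n = jax.local_device_count()  # 4 on this cluster

# lax.psum is the cross-device all-reduce that XLA lowers to NCCL on
# GPU; a pmap'd forward pass with no collectives never touches NCCL.
f = jax.pmap(lambda x: jax.lax.psum(x, axis_name='i'), axis_name='i')
print(f(jnp.arange(n, dtype=jnp.float32)))  # expect [6. 6. 6. 6.] with 4 devices

If that snippet fails with the same ncclGroupEnd() error, the problem is presumably in the NCCL setup rather than in the notebook code. The full traceback from the failing training step: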
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-57-176d6124ae02> in <module>()
11 opt_repl, loss_repl, update_rng_repl = update_fn_repl(
---> 12 opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
13 losses.append(loss_repl[0])
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
182 try:
--> 183 return fun(*args, **kwargs)
184 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in f_pmapped(*args, **kwargs)
1638 name=flat_fun.__name__, donated_invars=tuple(donated_invars),
-> 1639 global_arg_shapes=tuple(global_arg_shapes_flat))
1640 return tree_unflatten(out_tree(), out)
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
1620 assert len(params['in_axes']) == len(args)
-> 1621 return call_bind(self, fun, *args, **params)
1622
/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1551 tracers = map(top_trace.full_raise, args)
-> 1552 outs = primitive.process(top_trace, fun, tracers, params)
1553 return map(full_lower, apply_todos(env_trace_todo(), outs))
/usr/local/lib/python3.7/dist-packages/jax/core.py in process(self, trace, fun, tracers, params)
1623 def process(self, trace, fun, tracers, params):
-> 1624 return trace.process_map(self, fun, tracers, params)
1625
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
606 def process_call(self, primitive, f, tracers, params):
--> 607 return primitive.impl(f, *tracers, **params)
608 process_map = process_call
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in xla_pmap_impl(fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *args)
636 ("fingerprint", fingerprint))
--> 637 return compiled_fun(*args)
638
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in execute_replicated(compiled, backend, in_handler, out_handler, *args)
1159 input_bufs = in_handler(args)
-> 1160 out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
1161 if xla.needs_check_special():
UnfilteredStackTrace: RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:203: NCCL operation ncclGroupEnd() failed: unhandled system error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-57-176d6124ae02> in <module>()
10
11 opt_repl, loss_repl, update_rng_repl = update_fn_repl(
---> 12 opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
13 losses.append(loss_repl[0])
14 lrs.append(lr_fn(step))
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in execute_replicated(compiled, backend, in_handler, out_handler, *args)
1158 def execute_replicated(compiled, backend, in_handler, out_handler, *args):
1159 input_bufs = in_handler(args)
-> 1160 out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
1161 if xla.needs_check_special():
1162 for bufs in out_bufs:
RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:203: NCCL operation ncclGroupEnd() failed: unhandled system error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
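The "unhandled system error" part hides the underlying cause; my understanding is that NCCL logs the real error if its debug output is enabled before the first collective runs, e.g. at the top of the notebook (I have not captured these logs yet):

import os

# NCCL reads these environment variables when it initializes, so they
# must be set before the first pmap'd collective executes in the process.
os.environ['NCCL_DEBUG'] = 'INFO'
os.environ['NCCL_DEBUG_SUBSYS'] = 'ALL'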
Has anyone run into this problem before? Is there a way to fix it?