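I am running the TensorFlow Object Detection API following this guide (https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/training.html#configuring-a-training-job), but with slightly modified code for creating the record files, on the following system:
System information:
- OS platform and distribution: Ubuntu 20.04.1 LTS
- Python version:
- TensorFlow version: 2.4.0
- CUDA/cuDNN version: 11.0/8.0.5
- GPU model and memory: GeForce RTX 3090, 24268 MiB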
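I want to use the CenterNet MobileNetV2 FPN 512x512 model (the variant for boxes) from the TensorFlow 2 Detection Model Zoo (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md).
I set up the training job according to the guide and then ran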
python model_main_tf2.py --model_dir=models/my_centernet_mn_fpn --pipeline_config_path=models/my_centernet_mn_fpn/pipeline.config
When doing so, I get the following error
Traceback (most recent call last):
File "model_main_tf2.py", line 115, in <module>
tf.compat.v1.app.run()
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/platform/app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "model_main_tf2.py", line 106, in main
model_lib_v2.train_loop(
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/object_detection/model_lib_v2.py", line 636, in train_loop
loss = _dist_train_step(train_input_iter)
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
result = self._call(*args, **kwds)
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
return self._stateless_fn(*args, **kwds)
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2942, in __call__
return graph_function._call_flat(
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
return self._build_call_outputs(self._inference_function.call(
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
outputs = execute.execute(
File "/opt/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: indices[0] = 0 is not in [0, 0)
[[{{node GatherV2_7}}]]
[[MultiDeviceIteratorGetNextFromShard]]
[[RemoteCall]]
[[IteratorGetNext]]
(1) Invalid argument: indices[0] = 0 is not in [0, 0)
[[{{node GatherV2_7}}]]
[[MultiDeviceIteratorGetNextFromShard]]
[[RemoteCall]]
[[IteratorGetNext]]
[[ToAbsoluteCoordinates_42/Assert/AssertGuard/branch_executed/_386/_1231]]
0 successful operations.
0 derived errors ignored. [Op:__inference__dist_train_step_54439]
Function call stack:
_dist_train_step -> _dist_train_step
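When googling this error, some answers come up saying that the problem lies in how the TFRecord files were created, and that you need to add include_masks when creating them. However, I do not get this error when running other CenterNet models from the model zoo, so it seems odd that this would be the cause.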
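The message `indices[0] = 0 is not in [0, 0)` says that a gather was attempted on a tensor whose indexed dimension has size 0, so one thing that can be ruled out independently of include_masks is whether the modified record-creation code ever writes an image with empty or misaligned annotation lists. Below is a minimal sketch of such a check, run on the annotations before they are serialized. The function name and the dict keys (`xmins`, `xmaxs`, `ymins`, `ymaxs`, `classes`) are hypothetical and would need to be adapted to the actual script; it is not confirmed that any of these conditions is what triggers the error here.

```python
# Hypothetical pre-serialization check. The dict keys below are assumptions
# and must be adapted to whatever the record-creation script actually uses.

def find_annotation_problems(example):
    """Return a list of human-readable problems for one image's annotations.

    `example` is assumed to be a dict holding normalized box coordinates in
    parallel lists (xmins, xmaxs, ymins, ymaxs) plus integer class ids.
    """
    problems = []
    keys = ("xmins", "xmaxs", "ymins", "ymaxs", "classes")
    lengths = {k: len(example.get(k, [])) for k in keys}

    # All per-box lists must have the same length.
    if len(set(lengths.values())) != 1:
        problems.append(f"field lengths differ: {lengths}")
        return problems  # per-box checks are meaningless if lists are misaligned

    # An image without any boxes yields empty tensors downstream.
    if lengths["classes"] == 0:
        problems.append("no boxes (empty annotation lists)")
        return problems

    for i in range(lengths["classes"]):
        xmin, xmax = example["xmins"][i], example["xmaxs"][i]
        ymin, ymax = example["ymins"][i], example["ymaxs"][i]
        if not (0.0 <= xmin < xmax <= 1.0 and 0.0 <= ymin < ymax <= 1.0):
            problems.append(f"box {i} is degenerate or not normalized to [0, 1]")
        # Label map ids in the Object Detection API start at 1.
        if example["classes"][i] < 1:
            problems.append(f"box {i} has class id {example['classes'][i]}")
    return problems


if __name__ == "__main__":
    # Made-up annotations, only to show how the check would be called.
    samples = {
        "ok.jpg": {"xmins": [0.1], "xmaxs": [0.5], "ymins": [0.2],
                   "ymaxs": [0.9], "classes": [1]},
        "empty.jpg": {"xmins": [], "xmaxs": [], "ymins": [],
                      "ymaxs": [], "classes": []},
    }
    for name, sample in samples.items():
        print(name, find_annotation_problems(sample))
```

If every image passes a check like this, the annotation lists themselves are less likely to be the source of the empty tensor, and the remaining suspects would be the fields the pipeline expects but the records do not contain.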
Any ideas whether it could be something else?