1

我一直在 keras,tensorflow 1 中使用自定义数据生成器,

model.fit_generator(generator=training_generator,
                    validation_data=validation_generator, 
                    use_multiprocessing=True, epochs=epochs,
                    workers=workers, callbacks=callbacks_list, verbose=2)

效果很好。现在,当我切换到 tensorflow2 时,我发现不再支持 multi_gpu_model(model)。

正如文档中所建议的那样,我切换到 tf.distribute.MirroredStrategy(),因为我在具有 4 个 GPU 的无头服务器上运行。我还将生成器('training_generator')切换为 tf.data.Dataset 格式:

train_ds = tf.data.Dataset.from_generator(lambda: training_generator,
                                      output_types=((tf.float32, tf.float32, tf.float32, tf.float32), tf.float32),
                                      output_shapes=(([None, 224, 224, 3],
                                                    [None, 625],
                                                    [None, 224, 224, 3],
                                                    [None, 224, 224, 3]),
                                                    [None, 2])
                                      )

但是如何让它与多个线程一起运行呢?这是我尝试过的(都来自这里:https ://medium.com/@nimatajbakhsh/building-multi-threaded-custom-data-pipelines-for-tensorflow-f76e9b1a32f5 ):

  1. 用“地图”包裹它。这可行,但仍然在单线程中运行,因为我的 CPU 没有完全加载并且 GPU 正在挨饿。

train_dataset = train_ds.map(lambda x,y: (x,y), num_parallel_calls=workers)

  1. 使用“交错”

generators = tf.data.Dataset.from_tensor_slices(['Gen_0', 'Gen_1', 'Gen_2', 'Gen_3', 'Gen_4', 'Gen_5', 'Gen_6', 'Gen_7','Gen_8','Gen_9', 'Gen_10'])

train_dataset = generators.interleave(lambda x: tf.data.Dataset.from_generator(lambda: training_generator,
                                      output_types=((tf.float32, tf.float32, tf.float32, tf.float32), tf.float32),
                                      output_shapes=(([None, 224, 224, 3],
                                                    [None, 625],
                                                    [None, 224, 224, 3],
                                                    [None, 224, 224, 3]),
                                                    [None, 2])
                                      ),
                    
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)

` 这会加载 CPU 并很好地为 GPU 提供数据,但它似乎会创建我的数据集的副本。我想要的只是遍历整个数据集一次,像以前在 model.fit_generator() 中那样并行生成批次

任何帮助或见解表示赞赏!

4

0 回答 0