我刚刚升级到 tensorflow 2.3。我想制作自己的数据生成器进行训练。使用 tensorflow 1.x,我这样做了:
def get_data_generator(test_flag):
item_list = load_item_list(test_flag)
print('data loaded')
while True:
X = []
Y = []
for _ in range(BATCH_SIZE):
x, y = get_random_augmented_sample(item_list)
X.append(x)
Y.append(y)
yield np.asarray(X), np.asarray(Y)
data_generator_train = get_data_generator(False)
data_generator_test = get_data_generator(True)
model.fit_generator(data_generator_train, validation_data=data_generator_test,
epochs=10000, verbose=2,
use_multiprocessing=True,
workers=8,
validation_steps=100,
steps_per_epoch=500,
)
此代码适用于 tensorflow 1.x。在系统中创建了 8 个进程。处理器和视频卡加载完美。“数据加载”打印了 8 次。
使用 tensorflow 2.3 我收到警告:
警告:张量流:多处理可能与 TensorFlow 交互不良,导致非确定性死锁。对于高性能数据管道,建议使用 tf.data。
“数据加载”打印了一次(应该是 8 次)。GPU 没有被充分利用。它也有每个 epoch 的内存泄漏,所以在几个 epoch 后,traning 会停止。use_multiprocessing 标志没有帮助。
如何在 tensorflow(keras) 2.x 中制作可以轻松跨多个 CPU 进程并行化的生成器/迭代器?死锁和数据顺序并不重要。