假设我有 3 个 tfrecord 文件,即neg.tfrecord
, pos1.tfrecord
, pos2.tfrecord
.
我用
dataset = tf.data.TFRecordDataset(tfrecord_file)
此代码创建 3 个 Dataset 对象。
我的批量大小是 400,包括 200 个 neg 数据、100 个 pos1 数据和 100 个 pos2 数据。如何获得所需的数据集?
我将在 keras.fit()(急切执行)中使用这个数据集对象。
我的 tensorflow 版本是 1.13.1。
之前尝试获取每个数据集的迭代器,拿到数据后手动concat,但是效率低,GPU利用率也不高。