1

假设我有 3 个 tfrecord 文件,即neg.tfrecord, pos1.tfrecord, pos2.tfrecord.

我用

dataset = tf.data.TFRecordDataset(tfrecord_file)

此代码创建 3 个 Dataset 对象。

我的批量大小是 400,包括 200 个 neg 数据、100 个 pos1 数据和 100 个 pos2 数据。如何获得所需的数据集?

我将在 keras.fit()(急切执行)中使用这个数据集对象。

我的 tensorflow 版本是 1.13.1。

之前尝试获取每个数据集的迭代器,拿到数据后手动concat,但是效率低,GPU利用率也不高。

4

2 回答 2

1

您可以使用interleave

filenames = [tfrecord_file1, tfrecord_file2]
dataset = (Dataset.from_tensor_slices(filenames).interleave(lambda x:TFRecordDataset(x)
dataset = dataset.map(parse_fn)
...

或者您甚至可以尝试并行交错。见https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#interleave https://www.tensorflow.org/api_docs/python/tf/data/experimental/parallel_interleave

于 2019-03-14T06:05:45.153 回答
0

这适用于我目前在 Kaggle 上执行的一个项目。
我读了 5 个不同年份的数据集,并使用以下代码合并它们。
祝福-约翰-埃里克

帧=[df,df1,df2,df3,df4]

数据 = pd.concat(帧)

数据

于 2022-03-05T23:26:04.180 回答