0

我正在尝试为对比损失/三元组损失创建一个具有对(正/负)的数据集。

我有大量 10K+ 的类,我为每个类创建了一个 TFrecord,所以我现在有 10K+ TFRecord。到目前为止,似乎我需要将interleave函数与 TFRecord 一起使用block_length=2,以便在同一批次中从同一类中获取至少两个元素。https://www.tensorflow.org/api_docs/python/tf/data/Dataset#interleave

这是它的代码:

FILENAMES = ["1.tfrec", "2.tfrec" ... ]
NUM = len(FILENAMES) # number of classes
batch_size = 2048
dataset = tf.data.Dataset.from_tensor_slices(FILENAMES).interleave(lambda x:
    tf.data.TFRecordDataset(x, num_parallel_reads=AUTO).map(parse_example, num_parallel_calls=-1).shuffle(batch_size).prefetch(batch_size),
    cycle_length=-1, block_length=2, num_parallel_calls=-1).batch(64, drop_remainder=True)
# Compute number of classe in a dataset loop
counts = 0
sets= set()
print("Num classes : ", NUM)
t = tqdm(NUM)
r = len(sets) 

for batch in dataset:
        ind = set(batch[1][0].numpy().tolist())
        for v in ind:
            sets.add(v)

        t.update(len(sets) - r)
        r = len(sets)
        counts += 1
print(len(sets))

但是,我认为我的代码并没有达到我所看到的所有 TFrecord 文件。事实上,我从来没有得到所有可能的课程(10k+)。如果我遍历数据集一次,我大约有 2K 类到达。我的方法有问题吗?

4

0 回答 0