我想使用 tf.data.Dataset API 在 Tensorflow 2.0 中实现一个自定义小批量生成器。具体来说,我有图像数据,100 个类,每个类约 200 个示例。对于每个 mini-batch,我想随机抽取 P 个类别和每个类别的 K 个图像,以获得 mini-batch 中的总共 P*K 个示例(如论文In Defense of the Triplet Loss for Person Re-识别])。
我一直在搜索tf.data.Dataset的文档,但似乎找不到正确的方法。我已经研究了该from_generator
方法,但它似乎不适合这个,因为据我所知,它从头开始生成整个数据集。
在我看来,一种方法是创建一个类似于BatchDataset
可以在tf.data.Dataset 源代码中找到的新类,我会在其中以某种方式实现逻辑,但我希望有一个更简单的解决方案说实话。