我尝试使用 tf.data.Dataset.interleave 复制此处发布的解决方案,但不太确定如何将 interleave 方法应用于已创建的数据集对象。这是代码:
import tensorflow as tf
import numpy as np
# preparing data
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
class0=lambda features, label: label==0
class1=lambda features, label: label==1
class2=lambda features, label: label==2
ds_0=dataset.filter(class0)
ds_1=dataset.filter(class1)
ds_2=dataset.filter(class2)
我想通过从 ds_0、ds_1 和 ds_2 中同等采样来创建数据集。我应该通过map_func
什么?