我有一个产生特征和目标张量的函数。例如
x,t = myfunc() ##x,t tensors
如何将其与 TensorFlow 的数据集 API 集成以进行持续训练?理想情况下,我想使用数据集来设置批处理、转换等内容。
编辑澄清:问题是我不仅想将 x 和 t 放在我的图表中,还想从中创建一个数据集,以便我可以使用我为(正常)有限数据集实现的相同数据集处理,我可以加载到内存中并使用可初始化的迭代器输入同一个图形。
我有一个产生特征和目标张量的函数。例如
x,t = myfunc() ##x,t tensors
如何将其与 TensorFlow 的数据集 API 集成以进行持续训练?理想情况下,我想使用数据集来设置批处理、转换等内容。
编辑澄清:问题是我不仅想将 x 和 t 放在我的图表中,还想从中创建一个数据集,以便我可以使用我为(正常)有限数据集实现的相同数据集处理,我可以加载到内存中并使用可初始化的迭代器输入同一个图形。
假设x
和t
是tf.Tensor
对象,并my_func()
构建一个 TensorFlow 图,您可以使用以下方法与 `Dataset.map():
# Creates an infinite dataset with a dummy value. You can make this finite by
# specifying an explicit number of elements to `repeat()`.
dummy_dataset = tf.data.Dataset.from_tensors(0).repeat(None)
# Evaluates `my_func` once for each element in `dummy_dataset`.
dataset = dummy_dataset.map(lambda _: my_func())
tf.data.Dataset.from_tensors
如果 x 和 t 是张量,您可以通过调用或创建数据集(此处tf.data.Dataset.from_tensor_slices
的文档)。
它们之间的区别在于from_tensors
将输入张量组合成数据集中的单个元素。from_tensor_slices
为每个切片创建一个包含一个元素的数据集。