2

我正在使用张量流 2。

当使用Model.fit()带有 a 的方法时tf.data.Dataset,参数 ' batch_size' 被忽略。因此,要批量训练我的模型,我必须首先通过调用将我的样本数据集更改为批量样本数据集tf.data.Dataset.batch(batch_size)

然后,在阅读文档后,我不清楚该.fit()方法将如何在每个时期对我的数据集进行洗牌。

由于我的数据集是批次数据集,它会在批次之间打乱(批次保持不变)吗?或者它会打乱所有样本,然后将它们重新组合成新的批次(这是所需的行为)

非常感谢你的帮助。

4

1 回答 1

0

使用API时,该shuffle参数对功能没有影响。fittf.data.Dataset

如果我们阅读文档(重点是我的):

shuffle:布尔值(是否在每个 epoch 之前对训练数据进行洗牌)或 str(用于“批处理”)。当 x 是生成器时,此参数将被忽略。'batch' 是处理 HDF5 数据限制的特殊选项;它以批量大小的块进行洗牌。当 steps_per_epoch 不是 None 时无效。

这不是很清楚,但我们可以暗示使用 a 时将忽略 shuffle 参数tf.data.Dataset,因为它的行为类似于生成器。

可以肯定的是,让我们深入研究代码。如果我们查看该fit方法的代码,您会发现数据是由一个特殊的类处理的DataHandler。查看这个类的代码,我们看到这是一个处理不同类型数据的适配器类。我们对处理 tf.data.Dataset, 的类感兴趣,DatasetAdapter我们可以看到这个类没有考虑shuffle参数 :

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               steps=None,
               **kwargs):
    super(DatasetAdapter, self).__init__(x, y, **kwargs)
    # Note that the dataset instance is immutable, its fine to reuse the user
    # provided dataset.
    self._dataset = x

    # The user-provided steps.
    self._user_steps = steps

    self._validate_args(y, sample_weights, steps)

如果要对数据集进行随机播放,请使用API中的随机播放功能。tf.data.Dataset

于 2020-10-14T15:34:37.190 回答