7

我使用 tfrecords 存储数据,并使用 API 将它们作为张量读取Dataset,然后使用EstimatorAPI 执行训练。现在,我想对数据集中的每个项目进行在线数据增强,但尝试了一段时间后,我找不到出路。我想要随机翻转、随机旋转等机械手。

我正在按照教程中给出的说明使用自定义估计器,这是我的 CNN,但我不确定数据增强步骤发生在哪里。

4

1 回答 1

6

使用 TFRecords 不会阻止您进行数据扩充。

按照您在评论中链接的教程,大致情况如下:

  • 您从 TFRecords 文件创建数据集,并解析文件以获取一个image和一个label
dataset = tf.data.TFRecordDataset(filenames=filenames)
dataset = dataset.map(parse)
  • 您现在可以应用新的预处理功能在训练期间进行一些数据扩充
# Only do it when we are training
if train:
    dataset = dataset.map(train_preprocess)
  • train_preprocess函数可以是这样的:
def train_preprocess(image, label):
    flip_image = tf.image.random_flip_left_right(image)
    # Other transformations...
    return flip_image, label
于 2018-01-19T18:54:11.983 回答