我使用 tfrecords 存储数据,并使用 API 将它们作为张量读取Dataset
,然后使用Estimator
API 执行训练。现在,我想对数据集中的每个项目进行在线数据增强,但尝试了一段时间后,我找不到出路。我想要随机翻转、随机旋转等机械手。
我正在按照本教程中给出的说明使用自定义估计器,这是我的 CNN,但我不确定数据增强步骤发生在哪里。
我使用 tfrecords 存储数据,并使用 API 将它们作为张量读取Dataset
,然后使用Estimator
API 执行训练。现在,我想对数据集中的每个项目进行在线数据增强,但尝试了一段时间后,我找不到出路。我想要随机翻转、随机旋转等机械手。
我正在按照本教程中给出的说明使用自定义估计器,这是我的 CNN,但我不确定数据增强步骤发生在哪里。
使用 TFRecords 不会阻止您进行数据扩充。
按照您在评论中链接的教程,大致情况如下:
image
和一个label
dataset = tf.data.TFRecordDataset(filenames=filenames)
dataset = dataset.map(parse)
# Only do it when we are training
if train:
dataset = dataset.map(train_preprocess)
train_preprocess
函数可以是这样的:def train_preprocess(image, label):
flip_image = tf.image.random_flip_left_right(image)
# Other transformations...
return flip_image, label