3

使用 Tensorflow 的 Estimator API,我应该在管道中的哪个点执行数据增强?

根据这个官方Tensorflow 指南,执行数据增强的一个地方是input_fn

def parse_fn(example):
  "Parse TFExample records and perform simple data augmentation."
  example_fmt = {
    "image": tf.FixedLengthFeature((), tf.string, ""),
    "label": tf.FixedLengthFeature((), tf.int64, -1)
  }
  parsed = tf.parse_single_example(example, example_fmt)
  image = tf.image.decode_image(parsed["image"])

  # augments image using slice, reshape, resize_bilinear
  #         |
  #         |
  #         |
  #         v
  image = _augment_helper(image)

  return image, parsed["label"]

def input_fn():
  files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
  dataset = files.interleave(tf.data.TFRecordDataset)
  dataset = dataset.map(map_func=parse_fn)
  # ...
  return dataset

我的问题

如果我在内部执行数据增强input_fn,是否parse_fn返回单个示例或包含原始输入图像 + 所有增强变体的批次?如果它应该只返回一个 [augmented] 示例,我如何确保数据集中的所有图像都以其未增强的形式以及所有变体使用?

4

2 回答 2

1

如果您在数据集上使用迭代器,则您的 _augment_helper 函数将在数据集的每个迭代中被调用

将您的代码更改为

  ds_iter = dataset.make_one_shot_iterator()
  ds_iter = ds_iter.get_next()
  return ds_iter

我已经用一个简单的增强功能对此进行了测试

  def _augment_helper(image):
       print(image.shape)
       image = tf.image.random_brightness(image,255.0, 1)
       image = tf.clip_by_value(image, 0.0, 255.0)
       return image

将 255.0 更改为数据集中的最大值,我使用 255.0,因为我的示例数据集是 8 位像素值

于 2019-03-28T13:20:35.053 回答
0

它将为您对 parse_fn 的每次调用返回单个示例,然后如果您使用 .batch() 操作,它将返回一批已解析的图像

于 2020-07-14T22:35:13.917 回答