1

我正在尝试使用 tf.data API 将图像数据集加载到 GPU 上,因为该 API 提供了经过优化的性能。但不幸的是,经过 tf.data.Dataset.map() 处理后得到的数据集与 model.fit() 或 model.fit_generator() 不兼容。假设目录结构与 Keras ImageDataGenerator 所要求的目录结构相同。

def _list_entries(directory):
    """Return the joined path of every entry found directly inside *directory*."""
    return [os.path.join(directory, entry) for entry in os.listdir(directory)]


# One path per directory entry, in the order os.listdir() reports them.
files = _list_entries(train_dir)
val_files = _list_entries(val_dir)


def _load_image_and_mask(file_path):
    """Read an image and its mask and letterbox both to 544x960.

    The mask path is derived from ``file_path`` by replacing a trailing
    ``.jpg`` with ``.png`` and the first occurrence of ``Images`` with
    ``Label``.

    Args:
        file_path: path of the JPEG image (a string, or a scalar string
            tensor when called from ``Dataset.map``).

    Returns:
        ``(image, mask)``: the decoded 3-channel image and 1-channel mask,
        both resized with padding to height 544 and width 960. Values are
        not yet rescaled.
    """
    # The dot is escaped so that only a literal ".jpg" suffix is rewritten.
    mask_path = tf.strings.regex_replace(file_path, r'\.jpg$', '.png')
    mask_path = tf.strings.regex_replace(mask_path, 'Images', 'Label', replace_global=False)

    mask = tf.image.decode_png(tf.io.read_file(mask_path), channels=1)
    # NOTE(review): the default resize method is bilinear, which blends
    # neighbouring mask values; if the mask holds class ids, nearest-neighbour
    # resizing is probably wanted -- confirm against the label format.
    mask = tf.image.resize_image_with_pad(mask, target_height=544, target_width=960)

    image = tf.image.decode_jpeg(tf.io.read_file(file_path), channels=3)
    image = tf.image.resize_image_with_pad(image, target_height=544, target_width=960)
    return image, mask


def _scale_to_unit(image, mask):
    """Cast to float32, divide by 255 and drop the mask's channel axis."""
    # An explicit cast replaces convert_image_dtype(): on the float output of
    # the resize that call was a no-op, and on integer input it would have
    # rescaled the data a second time on top of the division below.
    image = tf.cast(image, tf.float32) / 255.
    mask = tf.cast(mask, tf.float32)[:, :, 0] / 255.
    return image, mask


def get_data_train(file_path: str) -> tuple:
    """Load one training sample with a random horizontal flip.

    Built only from TensorFlow ops, so it can be passed directly to
    ``Dataset.map`` without ``tf.py_func``.

    Args:
        file_path: path of the JPEG image (a scalar string tensor inside
            ``Dataset.map``).

    Returns:
        ``(image, mask)`` where ``image`` has shape ``[544, 960, 3]`` and
        ``mask`` has shape ``[544, 960]``; both are float32 divided by 255.
    """
    image, mask = _load_image_and_mask(file_path)

    # A Python ``if`` cannot branch on a tensor in graph mode, so the coin
    # flip is expressed with tf.cond. Image and mask are flipped together so
    # they stay aligned.
    do_flip = tf.random.uniform(shape=(), minval=0, maxval=1, dtype=tf.float32) > 0.5
    image, mask = tf.cond(
        do_flip,
        lambda: (tf.image.flip_left_right(image), tf.image.flip_left_right(mask)),
        lambda: (image, mask),
    )
    return _scale_to_unit(image, mask)


def get_data_validation(file_path: str) -> tuple:
    """Load one validation sample (no augmentation).

    Args:
        file_path: path of the JPEG image (a scalar string tensor inside
            ``Dataset.map``).

    Returns:
        ``(image, mask)`` where ``image`` has shape ``[544, 960, 3]`` and
        ``mask`` has shape ``[544, 960]``; both are float32 divided by 255.
    """
    image, mask = _load_image_and_mask(file_path)
    return _scale_to_unit(image, mask)


def configure_for_performance(dataset: tf.data.Dataset, batch_size=None,
                              shuffle_buffer_size=8, prefetch_buffer_size=8):
    """Cache, shuffle, batch and prefetch *dataset*.

    Called with only ``dataset`` it behaves exactly as before: cache,
    shuffle with a buffer of 8, batch by ``args.batch_size``, prefetch 8.

    Args:
        dataset: the dataset of unbatched elements to configure.
        batch_size: number of elements per batch. ``None`` (default) falls
            back to the module-level ``args.batch_size``.
        shuffle_buffer_size: size of the shuffle buffer. Pass ``None`` or
            ``0`` to skip shuffling (e.g. for a validation set).
        prefetch_buffer_size: number of batches to prefetch.

    Returns:
        The configured ``tf.data.Dataset``.
    """
    if batch_size is None:
        batch_size = args.batch_size

    # cache() stores elements exactly as produced upstream, so any random
    # transformation mapped before this call is frozen after the first pass.
    dataset = dataset.cache()
    # Shuffling happens after caching so each pass is reshuffled.
    if shuffle_buffer_size:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)
    return dataset


# Build the training pipeline from the list of image paths.
train_ds = tf.data.Dataset.from_tensor_slices(files)
# NOTE(review): get_data_train is written with TensorFlow ops only, so wrapping
# it in tf.py_func should not be needed; py_func outputs carry no static shape.
# Mapping the function directly, train_ds.map(get_data_train), is probably what
# is wanted -- confirm.
train_ds = train_ds.map(lambda inputs: tf.py_func(get_data_train, [inputs], Tout=[tf.float32, tf.float32]))
# Same construction for the validation paths.
val_ds = tf.data.Dataset.from_tensor_slices(val_files)
val_ds = val_ds.map(lambda inputs: tf.py_func(get_data_validation, [inputs], Tout=[tf.float32, tf.float32]))
# NOTE(review): this is the line that raises the TypeError. The previous map
# yields two components (image, mask), so the mapped function is called with two
# arguments, but this lambda accepts only one. Beyond that:
#   * Tensor.set_shape() takes a single shape and returns None, so the lambda
#     would not return the tensors even with the right arity;
#   * batching happens later, inside configure_for_performance(), so the
#     elements here are unbatched and the leading None dimension does not apply.
# A per-element version would set [544, 960, 3] on the image and [544, 960] on
# the mask inside a two-argument function and return both tensors.
# Only val_ds gets this treatment; train_ds is left without static shapes.
val_ds = val_ds.map(lambda x: x.set_shape([None, 544, 960, 3], [None, 544, 960]))
# Cache, shuffle, batch and prefetch both pipelines.
train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)

当我调用 model.fit() 时出现如下错误:

 val_ds = val_ds.map(lambda x: x.set_shape([None, 544, 960, 3], [None, 544, 960]))
  File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1038, in map
    return MapDataset(self, map_func)
  File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2611, in __init__
    map_func, "Dataset.map()", input_dataset)
  File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1860, in __init__
    self._function.add_to_graph(ops.get_default_graph())
  File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 479, in add_to_graph
    self._create_definition_if_needed()
  File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 335, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 344, in _create_definition_if_needed_impl
    self._capture_by_value, self._caller_device)
  File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 864, in func_graph_from_py_func
    outputs = func(*func_graph.inputs)
  File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1794, in tf_data_structured_function_wrapper
    ret = func(*nested_args)
TypeError: <lambda>() takes 1 positional argument but 2 were given

如果我不设置形状,即去掉下面这一行:

val_ds = val_ds.map(lambda x: x.set_shape([None, 544, 960, 3], [None, 544, 960]))

那么 model.fit() 会报错,提示张量的秩(rank)未知。根据我的调查,tf.py_func() 会导致形状信息丢失,因此需要调用 set_shape()。

我正在尝试使用 tf.data API 加载 Cityscapes 数据集的图像文件。

谢谢你

4

1 回答 1

0

tf.py_func 中的代码无法在 GPU 上运行,TensorFlow 文档中关于 tf.py_function 的部分对此有说明。

也许你应该为 map 写一个函数,例如:

def fct_for_map(img):
  """Skeleton of a function to pass to Dataset.map(); the body is to be filled in."""
  # Your code goes here: build `my_tensor` from `img` (placeholder name, not defined in this snippet).
  return my_tensor

然后再试试:

# Apply the function to every element of the dataset, without tf.py_func.
train_ds = train_ds.map(fct_for_map)

希望这对你有所帮助。

于 2021-05-07T13:16:47.313 回答