我正在尝试使用 tf.data API 将图像数据集加载到 GPU 上,因为它们提供了优化的性能。但不幸的是,使用 tf.data.Dataset.map() 函数不会返回与 model.fit() 或 model.fit_generator() 兼容的数据集。假设目录树与 keras ImageDataGenerator 所需的目录树相同。
files = [os.path.join(train_dir, file) for file in os.listdir(train_dir)]
val_files = [os.path.join(val_dir, file) for file in os.listdir(val_dir)]
def get_data_train(file_path: str) -> tuple:
mask_path = tf.strings.regex_replace(file_path, '.jpg$', '.png')
mask_path = tf.strings.regex_replace(mask_path, 'Images', 'Label', replace_global=False)
mask = tf.io.read_file(mask_path)
mask = tf.image.decode_png(mask, channels=1)
mask = tf.image.resize_image_with_pad(mask, target_height=544, target_width=960)
image = tf.io.read_file(file_path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize_image_with_pad(image, target_height=544, target_width=960)
if tf.random.uniform(shape=(), minval=0, maxval=1, dtype=tf.float32) > 0.5:
image = tf.image.flip_left_right(image)
mask = tf.image.flip_left_right(mask)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
mask = tf.image.convert_image_dtype(mask, dtype=tf.float32)
return image / 255., mask[:, :, 0] / 255.
def get_data_validation(file_path: str) -> tuple:
mask_path = tf.strings.regex_replace(file_path, '.jpg$', '.png')
mask_path = tf.strings.regex_replace(mask_path, 'Images', 'Label', replace_global=False)
mask = tf.io.read_file(mask_path)
mask = tf.image.decode_png(mask, channels=1)
mask = tf.image.resize_image_with_pad(mask, target_height=544, target_width=960)
image = tf.io.read_file(file_path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize_image_with_pad(image, target_height=544, target_width=960)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
mask = tf.image.convert_image_dtype(mask, dtype=tf.float32)
return image / 255., mask[:, :, 0] / 255.
def configure_for_performance(dataset: tf.data.Dataset):
dataset = dataset.cache()
dataset = dataset.shuffle(buffer_size=8)
dataset = dataset.batch(args.batch_size)
dataset = dataset.prefetch(buffer_size=8)
return dataset
train_ds = tf.data.Dataset.from_tensor_slices(files)
train_ds = train_ds.map(lambda inputs: tf.py_func(get_data_train, [inputs], Tout=[tf.float32, tf.float32]))
val_ds = tf.data.Dataset.from_tensor_slices(val_files)
val_ds = val_ds.map(lambda inputs: tf.py_func(get_data_validation, [inputs], Tout=[tf.float32, tf.float32]))
val_ds = val_ds.map(lambda x: x.set_shape([None, 544, 960, 3], [None, 544, 960]))
train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
当我使用 model.fit() 函数时出现此错误
val_ds = val_ds.map(lambda x: x.set_shape([None, 544, 960, 3], [None, 544, 960]))
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1038, in map
return MapDataset(self, map_func)
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2611, in __init__
map_func, "Dataset.map()", input_dataset)
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1860, in __init__
self._function.add_to_graph(ops.get_default_graph())
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 479, in add_to_graph
self._create_definition_if_needed()
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 335, in _create_definition_if_needed
self._create_definition_if_needed_impl()
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 344, in _create_definition_if_needed_impl
self._capture_by_value, self._caller_device)
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 864, in func_graph_from_py_func
outputs = func(*func_graph.inputs)
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1794, in tf_data_structured_function_wrapper
ret = func(*nested_args)
TypeError: <lambda>() takes 1 positional argument but 2 were given
如果我不设置形状
val_ds = val_ds.map(lambda x: x.set_shape([None, 544, 960, 3], [None, 544, 960]))
然后 model.fit() 抱怨未知等级的张量。根据我的研究, tf.py_func() 会导致形状数据丢失,因此需要 set_shape 。
我正在尝试使用 tf.data API 加载 Cityscapes 数据集的图像文件。
谢谢你