我有一个图像数据集,我通过tf.data.Dataset.list_files()
.
在我的.map()
函数中,我读取和解码图像,如下所示:
def map_function(filepath):
image = tf.io.read_file(filename=filepath)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [IMAGE_WIDTH, IMAGE_HEIGHT])
return image
如果我使用(下面的工作)
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.map(map_function)
for image in dataset.as_numpy_iterator():
#Correctly outputs the numpy array, no error is displayed/encountered
print(image)
但是,如果我使用(下面会抛出错误):
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.batch(32).map(map_function)
for image in dataset.as_numpy_iterator():
#Error is displayed
print(image)
ValueError:形状必须为 0 级,但对于具有输入形状的“ReadFile”(操作:“ReadFile”)为 1 级:[?]。
现在,根据这个:https://www.tensorflow.org/guide/data_performance#vectorizing_mapping,代码不应该失败并且预处理步骤应该被优化(批处理与一次性处理)。
我的代码中的错误在哪里?
***如果我使用map().batch()
它工作正常