我在尝试使用albumentations 库扩充图像时遇到此错误,该库使用tf.numpy_function 包装python 函数以通过此链接在tensorflow 中进行扩充:https ://albumentations.ai/docs/examples/tensorflow-example/
我已经使用 tensorflow 数据集 API 加载了我的图像数据集和目标标签。编码 :
img_paths = df['image_path'].values
target = df['target_label'].values
path_lis = tf.data.Dataset.from_tensor_slices(img_paths)
target_lis = tf.data.Dataset.from_tensor_slices(target)
list_ds = tf.data.Dataset.zip((path_lis, target_lis))
image_count = len(df)
val_size = int(image_count * 0.3)
train = list_ds.skip(val_size)
val = list_ds.take(val_size)
def process_path(file_path, target):
# load the raw data from the file as a string
img = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img, channels=3)
return img, target
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_data = train.map(process_path, num_parallel_calls=AUTOTUNE)
val_data = val.map(process_path, num_parallel_calls=AUTOTUNE)
# Augmentation using albumentations library
transforms = A.Compose([
A.Rotate(limit=40),
A.RandomBrightness(limit=0.1),
A.RandomContrast(limit=0.9, p=1),
A.HorizontalFlip(),
A.Resize(224, 224)
])
def aug_fn(image):
data = {"image": image}
aug_data = transforms(**data)
aug_img = aug_data["image"]
#target = aug_data["keypoints"][0]
aug_img = tf.cast(aug_img/255.0, tf.float32)
#aug_img = tf.image.resize(aug_img, size=[224, 224])
return aug_img
def process_aug(img, label):
aug_img = tf.numpy_function(func=aug_fn, inp=[img], Tout=[tf.float32])
return aug_img, label
# create dataset
train_ds = train_data.map(process_aug, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
val_ds = val_data.map(process_aug, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
def set_shapes(img, label):
img.set_shape([224, 224, 3])
label.set_shape([])
return img, label
train_ds = train_ds.map(set_shapes, num_parallel_calls=AUTOTUNE).batch(8).prefetch(AUTOTUNE)
val_ds = val_ds.map(set_shapes, num_parallel_calls=AUTOTUNE).batch(8).prefetch(AUTOTUNE)
def view_image(ds):
image, label = next(iter(ds)) # extract 1 batch from the dataset
image = image.numpy()
label = label.numpy()
fig = plt.figure(figsize=(22, 22))
for i in range(20):
ax = fig.add_subplot(4, 5, i+1, xticks=[], yticks=[])
ax.imshow(image[i])
ax.set_title(f"Label: {label[i]}")
view_image(train_ds)
完整的错误信息:
Traceback (most recent call last):
File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py", line 2102, in execution_mode
yield
File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 758, in _next_internal
output_shapes=self._flat_output_shapes)
File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_dataset_ops.py", line 2610, in iterator_get_next
_ops.raise_from_not_ok_status(e, name)
File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 6843, in raise_from_not_ok_status
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,224,224,3] but got [4,1,224,224,3]. [Op:IteratorGetNext]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\Arun\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2963, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-20-23a37450bee7>", line 13, in <module>
view_image(train_ds)
File "<ipython-input-20-23a37450bee7>", line 3, in view_image
image, label = next(iter(ds)) # extract 1 batch from the dataset
File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 736, in __next__
return self.next()
File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 772, in next
return self._next_internal()
File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 764, in _next_internal
return structure.from_compatible_tensor_list(self._element_spec, ret)
File "C:\Users\Arun\Anaconda3\lib\contextlib.py", line 99, in __exit__
self.gen.throw(type, value, traceback)
File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py", line 2105, in execution_mode
executor_new.wait()
File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\executor.py", line 67, in wait
pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,224,224,3] but got [8,1,224,224,3].
有人至少可以告诉我为什么会发生这个错误吗?提前致谢!