1

我在按照 albumentations 官方示例（https://albumentations.ai/docs/examples/tensorflow-example/）使用 albumentations 库扩充图像时遇到了此错误。该示例使用 tf.numpy_function 包装 Python 函数，从而在 TensorFlow 数据管道中进行数据增强。

我已经使用 TensorFlow 的 tf.data 数据集 API 加载了图像数据集及其目标标签。代码如下：

# Build a dataset of (image path, target label) pairs from the dataframe columns.
img_paths = df['image_path'].values
target = df['target_label'].values
path_lis, target_lis = (
    tf.data.Dataset.from_tensor_slices(column) for column in (img_paths, target)
)
list_ds = tf.data.Dataset.zip((path_lis, target_lis))

# The first 30% of rows (in dataframe order) become validation, the rest training.
# NOTE(review): nothing is shuffled before the split, so if df is sorted
# (e.g. by label) the two splits may be skewed -- confirm df ordering.
image_count = len(df)
val_size = int(image_count * 0.3)
val = list_ds.take(val_size)
train = list_ds.skip(val_size)


def process_path(file_path, target):
    """Read the JPEG file at ``file_path`` and decode it into a 3-channel image.

    Returns the decoded image tensor paired with ``target``, which is passed
    through untouched.
    """
    file_bytes = tf.io.read_file(file_path)  # raw file contents as a string tensor
    return tf.image.decode_jpeg(file_bytes, channels=3), target

# Decode the images of both splits. `num_parallel_calls` lets several files be
# loaded/processed at once; AUTOTUNE picks the degree of parallelism.
train_data, val_data = (
    split.map(process_path, num_parallel_calls=AUTOTUNE) for split in (train, val)
)

# Augmentation using the albumentations library.
# The steps are applied in the listed order. A.Resize comes last so every output
# is 224x224 regardless of the input size, which set_shapes relies on.
# NOTE(review): A.RandomBrightness / A.RandomContrast are deprecated in newer
# albumentations releases in favour of A.RandomBrightnessContrast -- confirm
# against the installed version.
transforms = A.Compose([
            A.Rotate(limit=40),
            A.RandomBrightness(limit=0.1),
            A.RandomContrast(limit=0.9, p=1),
            A.HorizontalFlip(),
            A.Resize(224, 224)
            ])

def aug_fn(image, pipeline=None):
    """Augment one image with albumentations and scale it by 1/255 to float32.

    This runs inside ``tf.numpy_function``, so it receives and returns plain
    NumPy arrays; no TensorFlow op is needed (or wanted) here.

    Args:
        image: Image as a NumPy array; presumably HxWx3 uint8 from
            ``decode_jpeg`` -- confirm if the input pipeline changes.
        pipeline: Callable invoked as ``pipeline(image=...)`` that returns a
            dict with an ``"image"`` entry (e.g. an ``A.Compose``). Defaults to
            the module-level ``transforms``. Pass a different pipeline to reuse
            this function, e.g. a resize-only one for validation data.

    Returns:
        The augmented image divided by 255.0, as an ``np.float32`` array.
    """
    import numpy as np

    if pipeline is None:
        pipeline = transforms
    aug_img = pipeline(image=image)["image"]
    # Divide in float64 first, then cast, giving the same values as the former
    # tf.cast(aug_img / 255.0, tf.float32) but as a NumPy array, which is what
    # tf.numpy_function expects its function to return.
    return (np.asarray(aug_img) / 255.0).astype(np.float32)

def process_aug(img, label):
    """Run ``aug_fn`` on one decoded image from within the tf.data graph.

    Args:
        img: Decoded image tensor (output of ``process_path``).
        label: Target label, passed through unchanged.

    Returns:
        ``(aug_img, label)``, where ``aug_img`` is the single float32 tensor
        computed by ``aug_fn``. Its static shape is unknown; ``set_shapes``
        declares it afterwards.
    """
    # Tout must be a single dtype, not the one-element list [tf.float32].
    # With a list, numpy_function returns a *list* holding one tensor, which
    # tf.data converts into a tensor with an extra leading axis of size 1.
    # Batching then yields [B, 1, 224, 224, 3] instead of the [B, 224, 224, 3]
    # declared by set_shapes, which is the "Incompatible shapes at component 0"
    # error.
    aug_img = tf.numpy_function(func=aug_fn, inp=[img], Tout=tf.float32)
    return aug_img, label

# create dataset: apply the albumentations augmentation to every element of
# both splits and prefetch.
# NOTE(review): val_ds goes through the same random augmentations as train_ds;
# validation data usually only needs resizing/normalisation -- confirm intent.
train_ds, val_ds = (
    split.map(process_aug, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
    for split in (train_data, val_data)
)

def set_shapes(img, label):
    """Declare static shapes on one (image, label) pair and return the pair.

    The image is declared as 224x224x3 and the label as a scalar. The output of
    ``tf.numpy_function`` has no static shape, so this has to be stated
    explicitly before batching.
    """
    image_shape = [224, 224, 3]
    img.set_shape(image_shape)
    label.set_shape([])
    return img, label

# Restore static shapes, then group 8 elements per batch and prefetch.
train_ds, val_ds = (
    split.map(set_shapes, num_parallel_calls=AUTOTUNE).batch(8).prefetch(AUTOTUNE)
    for split in (train_ds, val_ds)
)


def view_image(ds):
    """Plot the images of the first batch of ``ds`` with their labels.

    Args:
        ds: Batched dataset yielding ``(images, labels)`` tensors.

    At most 20 images are drawn, on a 4x5 grid. Fewer are drawn when the batch
    is smaller.
    """
    image, label = next(iter(ds))  # extract 1 batch from the dataset
    image = image.numpy()
    label = label.numpy()

    # Clamp to the actual batch size. The pipeline batches by 8, so a fixed
    # range(20) would index past the end of the batch and raise IndexError.
    n_show = min(len(image), 20)

    fig = plt.figure(figsize=(22, 22))
    for i in range(n_show):
        ax = fig.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
        ax.imshow(image[i])
        ax.set_title(f"Label: {label[i]}")

# Preview the first batch of the augmented training dataset.
view_image(train_ds)

完整的错误信息:

Traceback (most recent call last):
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py", line 2102, in execution_mode
    yield
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 758, in _next_internal
    output_shapes=self._flat_output_shapes)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_dataset_ops.py", line 2610, in iterator_get_next
    _ops.raise_from_not_ok_status(e, name)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 6843, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,224,224,3] but got [4,1,224,224,3]. [Op:IteratorGetNext]

During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "C:\Users\Arun\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-20-23a37450bee7>", line 13, in <module>
    view_image(train_ds)
  File "<ipython-input-20-23a37450bee7>", line 3, in view_image
    image, label = next(iter(ds)) # extract 1 batch from the dataset
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 736, in __next__
    return self.next()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 772, in next
    return self._next_internal()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\data\ops\iterator_ops.py", line 764, in _next_internal
    return structure.from_compatible_tensor_list(self._element_spec, ret)
  File "C:\Users\Arun\Anaconda3\lib\contextlib.py", line 99, in __exit__
    self.gen.throw(type, value, traceback)
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\context.py", line 2105, in execution_mode
    executor_new.wait()
  File "C:\Users\Arun\Anaconda3\lib\site-packages\tensorflow\python\eager\executor.py", line 67, in wait
    pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,224,224,3] but got [8,1,224,224,3].

能否至少有人告诉我为什么会出现这个错误？提前致谢！

4

1 回答 1

0

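img_shape 应该是 (224, 224, 3) 而不是 [224, 224, 3]

（注：实际上 set_shape 对列表和元组的处理完全相同，这并不是报错的原因。真正的原因在 process_aug 中：tf.numpy_function 的 Tout=[tf.float32] 是一个列表，因此返回值是只含一个张量的列表，tf.data 会把它打包成多出一维的 [1, 224, 224, 3] 张量，批处理后就变成了 [8, 1, 224, 224, 3]。把它改为 Tout=tf.float32 即可解决。）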

例如:

def set_shapes(img, label, img_shape=(120,120,3)):
    """Declare static shapes on an (image, label) pair and return the pair.

    ``img_shape`` is applied to the image; the label is declared a scalar.
    """
    for tensor, shape in ((img, img_shape), (label, [])):
        tensor.set_shape(shape)
    return img, label
于 2021-02-03T06:34:57.993 回答