1

当我尝试在tf.data.Dataset管道中裁剪一批图像时出现以下错误:

InvalidArgumentError:输入形状轴 0 必须等于 4,得到形状 [5] [[{{node crop_to_bounding_box/unstack}}]] [Op:IteratorGetNext]

def crop(img_batch, label_batch):
    #cropped_image = img_batch
    cropped_image = tf.image.crop_to_bounding_box(img_batch, 0, 0, 100, 100)
    return cropped_image, label_batch


train_dataset_cropped = train_dataset.map(crop)

但是当我尝试运行以下 for 循环时,我得到了提到的错误:

for img_batch, label_batch in train_dataset_cropped:
    print(type(img_batch), img_batch.shape, label_batch.shape)

请注意,管道在没有函数tf.image.crop_to_bounding_box内部的情况下crop工作(直接使用cropped_image = img_batch)。

您知道如何在 tf.data.Dataset 管道中正确裁剪一批图像吗?

4

1 回答 1

0

我没有找到任何文档,但我认为您不能从tf.image将在tf.data.Dataset.map. 对于您的问题,一个简单的解决方法是:

def crop(img_batch, label_batch):
    cropped_image = img_batch[:, :100, :100] # if your dataset is already batched
    # cropped_image = img_batch[:100, :100] # otherwise
    return cropped_image, label_batch
于 2021-12-06T15:19:33.203 回答