当我尝试在tf.data.Dataset
管道中裁剪一批图像时出现以下错误:
InvalidArgumentError:输入形状轴 0 必须等于 4,得到形状 [5] [[{{node crop_to_bounding_box/unstack}}]] [Op:IteratorGetNext]
def crop(img_batch, label_batch):
#cropped_image = img_batch
cropped_image = tf.image.crop_to_bounding_box(img_batch, 0, 0, 100, 100)
return cropped_image, label_batch
train_dataset_cropped = train_dataset.map(crop)
但是当我尝试运行以下 for 循环时,我得到了提到的错误:
for img_batch, label_batch in train_dataset_cropped:
print(type(img_batch), img_batch.shape, label_batch.shape)
请注意,管道在没有函数tf.image.crop_to_bounding_box
内部的情况下crop
工作(直接使用cropped_image = img_batch
)。
您知道如何在 tf.data.Dataset 管道中正确裁剪一批图像吗?