我一直在尝试将我构建的生成器转换为 tf.data.dataset。我已经走了很远,现在我有一些像这样简单的东西
def parse_image(filename):
file = tf.io.read_file(filename) # this will work only with filename as tensor
image = tf.image.decode_image(file)
return image
def transform_img(img):
img = parse_image(img).numpy()
img = transforms_train(image = img)["image"]
return img
当我在文件名本身上调用它时,transform img 会按预期工作。喜欢:
plt.imshow(transform_img(array_of_filenames[0]))
但是当我将它映射到数据集时
dataset = tf.data.Dataset.from_tensor_slices(array_of_filenames)
dataset = dataset.map(transform_img)
我得到标题中的错误。
我又在做傻事了不是吗?感谢您的帮助!