0

我一直在尝试将我构建的生成器转换为 tf.data.dataset。我已经走了很远,现在我有一些像这样简单的东西

def parse_image(filename):
    file = tf.io.read_file(filename) # this will work only with filename as tensor
    image = tf.image.decode_image(file)
    return image

def transform_img(img):
  img = parse_image(img).numpy()
  img = transforms_train(image = img)["image"]
  return img

当我在文件名本身上调用它时,transform img 会按预期工作。喜欢:

plt.imshow(transform_img(array_of_filenames[0]))

但是当我将它映射到数据集时

dataset = tf.data.Dataset.from_tensor_slices(array_of_filenames)
dataset = dataset.map(transform_img)

我得到标题中的错误。

我又在做傻事了不是吗?感谢您的帮助!

4

1 回答 1

1

不能在 tensorflow 数据集的 map 函数中使用 numpy。否则,您需要将函数包装在tf.py_functionor中tf.numpy_function。所以它应该如下所示:

dataset = dataset.map(lambda: item: tf.py_function(transform_img, [item], [tf.float32]))

的第一个参数py_function是你想要的预处理函数,第二个参数是传递给函数的参数。最后一个参数是预处理函数返回的数据类型。(同样适用于tf.numpy_function

我不记得在文档中阅读过这个,但在教程中,你可以在这里找到它。

于 2021-05-20T12:37:22.507 回答