0

对于一个项目,我使用 tf.data.Dataset 来编写输入管道。输入是图像 RGB 标签是图像中用于生成热图的对象的 2D 坐标列表

这是重现问题的MWE。

 def encode_images(image, label):
        """

        Parameters
        ----------
        image
        label

        Returns
        -------

        """
        # load image
        # here the normal code
        # img_contents = tf.io.read_file(image)
        # # decode the image
        # img = tf.image.decode_jpeg(img_contents, channels=3)
        # img = tf.image.resize(img, (256, 256))
        # img = tf.cast(img, tf.float32)

        # this is just for testing
        image = tf.random.uniform(
            (256, 256, 3), minval=0, maxval=255, dtype=tf.dtypes.float32, seed=None, name=None
        )
        return image, label

    def generate_heatmap(image, label):
        """

        Parameters
        ----------
        image
        label

        Returns
        -------

        """

        start = 0.5
        sigma=3
        img_shape = (image.shape[0] , image.shape[1] )
        density_map = np.zeros(img_shape, dtype=np.float32)
        for center_x, center_y in label[0]:
            for v_y in range(img_shape[0]):
                for v_x in range(img_shape[1]):
                    x = start + v_x
                    y = start + v_y
                    d2 = (x - center_x) * (x - center_x) + (y - center_y) * (y - center_y)
                    exp = d2 / (2.0 * sigma**2)
                    if exp > math.log(100):
                        continue
                    density_map[v_y, v_x] = math.exp(-exp)
        return density_map


    X = ["img1.png", "img2.png", "img3.png", "img4.png", "img5.png"]
    y = [[[2, 3], [100, 120], [100, 120]],
         [[2, 3], [100, 120], [100, 120], [2, 1]],
         [[2, 3], [100, 120], [100, 120], [10, 10], [11, 12]],
         [[2, 3], [100, 120], [100, 120], [10, 10], [11, 12], [10, 2]],
         [[2, 3], [100, 120], [100, 120]]
         ]
    dataset = tf.data.Dataset.from_tensor_slices((X, tf.ragged.constant(y)))
    dataset = dataset.map(encode_images, num_parallel_calls=8)
    dataset = dataset.map(generate_heatmap, num_parallel_calls=8)
    dataset = dataset.batch(1, drop_remainder=False)

问题是在generate_heatmap()函数中,我使用 numpy 数组通过索引修改元素,这在 tensorflow 中是不可能的。我尝试迭代标签张量,这在 tensorflow 中是不可能的。另一件事是急切模式似乎没有启用 tf.data.Dataset!有什么建议可以解决这个问题!我认为在 pytorch 中这样的代码可以快速完成而不会受苦:)!

4

0 回答 0