对于一个项目,我使用 tf.data.Dataset 来编写输入管道。输入是图像 RGB 标签是图像中用于生成热图的对象的 2D 坐标列表
这是重现问题的MWE。
def encode_images(image, label):
"""
Parameters
----------
image
label
Returns
-------
"""
# load image
# here the normal code
# img_contents = tf.io.read_file(image)
# # decode the image
# img = tf.image.decode_jpeg(img_contents, channels=3)
# img = tf.image.resize(img, (256, 256))
# img = tf.cast(img, tf.float32)
# this is just for testing
image = tf.random.uniform(
(256, 256, 3), minval=0, maxval=255, dtype=tf.dtypes.float32, seed=None, name=None
)
return image, label
def generate_heatmap(image, label):
"""
Parameters
----------
image
label
Returns
-------
"""
start = 0.5
sigma=3
img_shape = (image.shape[0] , image.shape[1] )
density_map = np.zeros(img_shape, dtype=np.float32)
for center_x, center_y in label[0]:
for v_y in range(img_shape[0]):
for v_x in range(img_shape[1]):
x = start + v_x
y = start + v_y
d2 = (x - center_x) * (x - center_x) + (y - center_y) * (y - center_y)
exp = d2 / (2.0 * sigma**2)
if exp > math.log(100):
continue
density_map[v_y, v_x] = math.exp(-exp)
return density_map
X = ["img1.png", "img2.png", "img3.png", "img4.png", "img5.png"]
y = [[[2, 3], [100, 120], [100, 120]],
[[2, 3], [100, 120], [100, 120], [2, 1]],
[[2, 3], [100, 120], [100, 120], [10, 10], [11, 12]],
[[2, 3], [100, 120], [100, 120], [10, 10], [11, 12], [10, 2]],
[[2, 3], [100, 120], [100, 120]]
]
dataset = tf.data.Dataset.from_tensor_slices((X, tf.ragged.constant(y)))
dataset = dataset.map(encode_images, num_parallel_calls=8)
dataset = dataset.map(generate_heatmap, num_parallel_calls=8)
dataset = dataset.batch(1, drop_remainder=False)
问题是在generate_heatmap()
函数中,我使用 numpy 数组通过索引修改元素,这在 tensorflow 中是不可能的。我尝试迭代标签张量,这在 tensorflow 中是不可能的。另一件事是急切模式似乎没有启用 tf.data.Dataset
!有什么建议可以解决这个问题!我认为在 pytorch 中这样的代码可以快速完成而不会受苦:)!