我想使用 tf.data API 创建图像序列样本。但到目前为止,似乎没有简单的方法可以连接多个图像以形成单个样本。我尝试使用 dataset.window 函数,它将我的图像正确分组。但我不知道如何连接它们。
import tensorflow as tf
from glob import glob
IMG_WIDTH = 256
IMG_HEIGHT = 256
def load_and_process_image(path):
img = tf.io.read_file(path)
img = tf.image.decode_jpeg(img, channels=3)
img = tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])
img = tf.reshape(img, shape=(IMG_WIDTH, IMG_HEIGHT, 1, 3))
return img
def create_dataset(files, time_distance=8, frame_step=1):
dataset = tf.data.Dataset.from_tensor_slices(files)
dataset = dataset.map(load_and_process_image)
dataset = dataset.window(time_distance, 1, frame_step, True)
# TODO: Concatenate elements from dataset.window
return dataset
files = sorted(glob('some/path/*.jpg'))
images = create_dataset(images)
我知道我可以将图像序列保存为 TFRecords,但这会使我的数据管道更加不灵活,并且会消耗大量内存。
我的输入批次应具有 N x W x H x T x C 的形式(N:样本数 W:图像宽度 H:图像高度 T:图像序列长度 C:图像通道)。