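If anyone could nudge me toward the right approach, I would really appreciate it.
I'm trying to build a data pipeline to train an object detection model, and I want to use tf.data.Dataset for it. No matter how I approach the problem, I run into trouble. The code below is the closest I have gotten, but it only works with a batch size of 1; a batch size of 2 or more gives me a batching error. As soon as a batch contains more than one image, and the images have different numbers of bounding boxes, I start getting this error: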
InvalidArgumentError: Cannot add tensor to the batch: number of elements does not match. Shapes are: [tensor]: [5,4], [batch]: [7,4]
In the error above, the first image has 7 bounding boxes and the second has 5.
Here is my latest code for building the pipeline:
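The error comes from `Dataset.batch`, which stacks the elements of a batch into one dense tensor and therefore needs every element to have the same shape. A minimal sketch, using a hypothetical `boxes` list of two made-up arrays with 7 and 5 rows in place of real labels, reproduces the failure and shows that `padded_batch` avoids it by padding the box dimension up to the largest count in the batch:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for two images' boxes: 7 and 5 boxes, 4 coordinates each.
boxes = [np.zeros((7, 4), np.float32), np.ones((5, 4), np.float32)]

ds = tf.data.Dataset.from_generator(
    lambda: iter(boxes),
    output_signature=tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
)

# batch() tries to stack a [7, 4] and a [5, 4] tensor, which fails.
try:
    next(iter(ds.batch(2)))
    batch_failed = False
except tf.errors.InvalidArgumentError:
    batch_failed = True

# padded_batch() pads the first dimension to the largest size in the batch (7),
# filling the extra rows with the chosen padding value.
padded = next(iter(ds.padded_batch(2, padding_values=-1.0)))
print(batch_failed)  # True
print(padded.shape)  # (2, 7, 4)
```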
import os

import numpy as np
import pandas as pd
import tensorflow as tf

# image_paths, label_source and class2lbl are defined elsewhere in the script.

def data_generator():
    '''Data generator used to produce training samples.
    Yields: a dictionary containing
        img_path: string containing the path of the image
        all_bbox: n x 4 numpy array
        cls_lbl: numpy array of length n
        ** where n is the number of objects in the image
    '''
    while True:
        for img_path in image_paths:
            lbl_file = label_source + '/' + os.path.basename(img_path).replace('.png', '.txt')
            lbl_df = pd.read_csv(lbl_file, sep=r'\s+', header=None, engine='python')
            all_bbox = []
            cls_lbl = []
            # load bounding boxes and their classes, skipping 'Misc' and 'DontCare'
            for _, row in lbl_df.iterrows():
                if row[0] in ['Misc', 'DontCare']:
                    continue
                x_t, y_t = int(row[4]), int(row[5])
                x_b, y_b = int(row[6]), int(row[7])
                all_bbox.append([x_t, y_t, x_b, y_b])
                cls_lbl.append(class2lbl[row[0]])
            # reshape keeps the (n, 4) shape even when every object was skipped
            yield {"img_path": img_path,
                   "all_bbox": np.array(all_bbox).reshape(-1, 4),
                   "cls_lbl": np.array(cls_lbl)}
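The per-row `iterrows` loop in `data_generator` can also be written with pandas boolean indexing. A minimal sketch, assuming KITTI-style label lines (which the 'Misc'/'DontCare' classes and the column layout suggest: class name in column 0, box corners in columns 4 to 7) and a hypothetical `class2lbl` mapping, since the question does not show its contents:

```python
import io

import numpy as np
import pandas as pd

# Hypothetical mapping; the question's class2lbl is not shown.
class2lbl = {"Car": 0, "Pedestrian": 1, "Cyclist": 2}


def parse_labels(lbl_file):
    """Return (n, 4) int boxes and (n,) int labels, skipping Misc/DontCare rows."""
    lbl_df = pd.read_csv(lbl_file, sep=r"\s+", header=None, engine="python")
    kept = lbl_df[~lbl_df[0].isin(["Misc", "DontCare"])]
    # Columns 4..7 are the box corners; astype truncates toward zero like int().
    all_bbox = kept[[4, 5, 6, 7]].to_numpy().astype(np.int64).reshape(-1, 4)
    cls_lbl = kept[0].map(class2lbl).to_numpy().astype(np.int64)
    return all_bbox, cls_lbl


# Three KITTI-style lines written inline so the sketch runs without files.
text = (
    "Car 0.00 0 -1.58 587.01 173.33 614.12 200.12 1.65 1.67 3.64 -0.65 1.71 46.70 -1.59\n"
    "DontCare -1 -1 -10 503.89 169.71 590.61 190.13 -1 -1 -1 -1000 -1000 -1000 -10\n"
    "Pedestrian 0.00 0 0.21 423.17 173.67 433.17 224.03 1.87 0.50 0.90 -5.42 1.50 13.43 -0.18\n"
)
boxes, labels = parse_labels(io.StringIO(text))
print(labels)  # [0 1]
```

Filtering the whole frame at once avoids building Python lists row by row, and the `reshape(-1, 4)` keeps the box array two-dimensional even when a file contains only ignored objects.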
def image_loader(sample):
    '''Load the image from sample['img_path'] and add it to the dictionary.
    '''
    raw_img = tf.io.read_file(sample['img_path'])
    img = tf.io.decode_png(raw_img)
    sample["img"] = img
    sample["all_bbox"] = tf.cast(sample['all_bbox'], dtype=tf.float32)
    sample["cls_lbl"] = tf.cast(sample['cls_lbl'], dtype=tf.float32)
    return sample
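Calling `tf.io.decode_png` without `channels` leaves the channel dimension statically unknown in the dataset's `element_spec`. A sketch of the same loader with `channels=3` (an assumption that the images are RGB), run on a tiny PNG written to a temporary file so that it is self-contained:

```python
import os
import tempfile

import tensorflow as tf


def image_loader(sample):
    """Decode the PNG at sample['img_path'] and attach it as sample['img']."""
    raw_img = tf.io.read_file(sample["img_path"])
    # channels=3 (assumed RGB) gives the decoded tensor a static channel dimension.
    sample["img"] = tf.io.decode_png(raw_img, channels=3)
    sample["all_bbox"] = tf.cast(sample["all_bbox"], tf.float32)
    sample["cls_lbl"] = tf.cast(sample["cls_lbl"], tf.float32)
    return sample


# Write a tiny 4x6 RGB PNG to a temporary file (illustration only).
tmp_path = os.path.join(tempfile.mkdtemp(), "toy.png")
tf.io.write_file(tmp_path, tf.io.encode_png(tf.zeros((4, 6, 3), tf.uint8)))

ds = tf.data.Dataset.from_tensors({
    "img_path": tf.constant(tmp_path),
    "all_bbox": tf.constant([[1, 1, 3, 2]], tf.int64),
    "cls_lbl": tf.constant([0], tf.int64),
}).map(image_loader)

sample = next(iter(ds))
print(ds.element_spec["img"].shape)  # (None, None, 3)
print(sample["img"].shape)           # (4, 6, 3)
```

The decoded images keep their own height and width, so images of different sizes would also need padding or resizing before they can be batched.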
train_dataset = tf.data.Dataset.from_generator(data_generator,
                                               output_types={"img_path": tf.string,
                                                             "all_bbox": tf.float32,
                                                             "cls_lbl": tf.float32})
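`output_types` only tells `from_generator` the dtypes, so every component ends up with an unknown shape (and `output_types` is marked deprecated in recent TensorFlow releases). The `output_signature` argument lets each component's shape be declared, with `None` for the per-image box count. A sketch with a hypothetical `toy_generator` standing in for `data_generator`:

```python
import numpy as np
import tensorflow as tf


def toy_generator():
    """Hypothetical stand-in for data_generator: two samples with 3 and 1 boxes."""
    for n in (3, 1):
        yield {"img_path": "unused.png",
               "all_bbox": np.zeros((n, 4), np.float32),
               "cls_lbl": np.zeros((n,), np.float32)}


train_dataset = tf.data.Dataset.from_generator(
    toy_generator,
    output_signature={
        "img_path": tf.TensorSpec(shape=(), dtype=tf.string),
        # None = number of boxes varies per image; 4 = box coordinates.
        "all_bbox": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        "cls_lbl": tf.TensorSpec(shape=(None,), dtype=tf.float32),
    },
)
print(train_dataset.element_spec["all_bbox"])
```

With the rank of each component known, later stages such as `map` and `padded_batch` can reason about the shapes instead of treating them as fully unknown.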
train_dataset = train_dataset.map(image_loader)
train_dataset = train_dataset.batch(2)
val = next(iter(train_dataset))
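One common way to batch samples with different numbers of boxes is `padded_batch`, which pads each variable dimension up to the largest size in the batch. A sketch on hypothetical toy samples (images of different sizes with 7 and 5 boxes), assuming the dataset is built with an `output_signature` so every component has a known rank; -1 is an arbitrary padding value, chosen on the assumption that real class ids are non-negative, so that padded boxes and labels can be masked out later:

```python
import numpy as np
import tensorflow as tf


def toy_generator():
    """Hypothetical samples: images of different sizes with 7 and 5 boxes."""
    for n, (h, w) in ((7, (4, 6)), (5, (5, 5))):
        yield {"img": np.zeros((h, w, 3), np.uint8),
               "all_bbox": np.ones((n, 4), np.float32),
               "cls_lbl": np.zeros((n,), np.float32)}


ds = tf.data.Dataset.from_generator(
    toy_generator,
    output_signature={
        "img": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
        "all_bbox": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        "cls_lbl": tf.TensorSpec(shape=(None,), dtype=tf.float32),
    },
)

# Pad every variable dimension to the largest value in the batch.
# Padding values must match each component's dtype.
ds = ds.padded_batch(
    2,
    padding_values={
        "img": tf.constant(0, tf.uint8),
        "all_bbox": -1.0,
        "cls_lbl": -1.0,
    },
)

batch = next(iter(ds))
print(batch["img"].shape)       # (2, 5, 6, 3)
print(batch["all_bbox"].shape)  # (2, 7, 4)
print(batch["cls_lbl"].shape)   # (2, 7)
```

A mask such as `batch["cls_lbl"] >= 0` then recovers the real objects in each image. Recent TensorFlow releases also offer `Dataset.ragged_batch`, which returns `tf.RaggedTensor` components instead of padding; which one fits better depends on what the model's loss expects.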