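I am trying to create a custom dataset for a token classification model by following a tutorial (reference code). I am new to all of this and would appreciate any help getting on the right track.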

So far I have managed to create train/test data frames that look like this:

train_df.head(5)
words   labels  boxes
0   [Page, 1, of, 2, BILL, OF, LADING, -, SHORT, F...   [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...   [[880, 33, 922, 46], [922, 33, 932, 46], [932,...
1   [Pickup, Date:, 09/06/2021, BILL, OF, LADING, ...   [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...   [[37, 26, 82, 38], [82, 26, 119, 38], [119, 26...
2   [Pickup, Date:, 09/06/2021, BILL, OF, LADING, ...   [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...   [[40, 27, 85, 38], [85, 27, 122, 38], [122, 27...
3   [BILL, OF, LADING, Page, 1, Date:, 9/16/2021, ...   [O, O, O, O, O, O, O, O, O, O, O, O, O, B-SCit...   [[398, 30, 462, 45], [462, 30, 494, 45], [494,...
4   [Pickup, Date:, 09/06/2021, BILL, OF, LADING, ...   [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...   [[39, 25, 84, 37], [84, 25, 122, 37], [122, 25...

I also have this preprocess_data() function, but I don't know how to apply it to the df and build a custom dataset from the result.

from datasets import Array2D, Array3D, ClassLabel, Features, Sequence, Value

# we need to define custom features
features = Features({
    'image': Array3D(dtype="int64", shape=(3, 224, 224)),
    'input_ids': Sequence(feature=Value(dtype='int64')),
    'attention_mask': Sequence(Value(dtype='int64')),
    'token_type_ids': Sequence(Value(dtype='int64')),
    'bbox': Array2D(dtype="int64", shape=(512, 4)),
    'labels': Sequence(ClassLabel(names=labels)),
})
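
The labels list passed to ClassLabel above, and the label2id mapping used further down, are not shown in this post. A minimal sketch of one way they could be derived from the labels column follows. It assumes labels_list holds one list of string tags per document (for example train_df["labels"].tolist()); the tag "B-FIELD" is made up for illustration and is not one of the real tags.

```python
# Hypothetical stand-in for train_df["labels"].tolist(); "B-FIELD" is a made-up tag.
labels_list = [
    ["O", "O", "B-FIELD"],
    ["O", "B-FIELD", "O", "O"],
]

# Collect the distinct tags; sorting keeps the ids stable across runs.
labels = sorted({tag for seq in labels_list for tag in seq})

# Forward and reverse lookups between tag strings and integer ids.
label2id = {tag: idx for idx, tag in enumerate(labels)}
id2label = {idx: tag for tag, idx in label2id.items()}

print(labels)
print(label2id)
```

Sorting is only one possible convention; any fixed order works as long as the same labels list is used for ClassLabel and for label2id.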

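from PIL import Image

def preprocess_data(examples):
    # With batched=True, every column in examples is a list with one entry per row.
    images = [Image.open(path).convert("RGB") for path in examples['image_path']]
    words = examples['words']
    boxes = examples['boxes']
    # Each row holds a list of string labels, so map them to ids row by row.
    word_labels = [[label2id[label] for label in row] for row in examples['labels']]
    encoded_inputs = processor(images, words, boxes=boxes, word_labels=word_labels, return_tensors='pt',
                               padding="max_length", truncation=True)

    return encoded_inputs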

Applying the following code to a single record of the df gives the expected result (per the documentation):

encoding = processor(Image.open(image).convert("RGB") , 
                     sentences_list[0], 
                     boxes=bbox_list[0], 
                     word_labels=[label2id[label] for label in labels_list[0]], # numeric expected. string not allowed.
                     return_tensors="pt", padding='max_length', truncation=True)