4

我正在使用 Scibert 预训练模型来获取各种文本的嵌入。代码如下:

from transformers import *

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', model_max_length=512, truncation=True)
model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')

我已将最大长度和截断参数都添加到标记器中,但不幸的是,它们不会截断结果。如果我通过标记器运行更长的文本:

inputs = tokenizer("""long text""")

我收到以下错误:

令牌索引序列长度大于此模型的指定最大序列长度 (605 > 512)。在模型中运行此序列将导致索引错误

现在很明显,由于张量序列太长,我无法在模型中运行它。截断输入以适应最大序列长度 512 的最简单方法是什么?

4

1 回答 1

9

truncation不是类构造函数的参数(类引用),而是__call__方法的参数。因此,您应该使用:

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', model_max_length=512)

len(tokenizer(text, truncation=True).input_ids)

输出:

512
于 2020-11-27T13:48:13.963 回答