1

我正在使用 GPT-Neo 模型transformers来生成文本。因为我使用的提示以 开头'{',所以我想在'}'生成配对后停止该句子。我发现StoppingCriteria源代码中有一个方法,但没有进一步说明如何使用它。有没有人找到一种方法来尽早停止模型生成?谢谢!

这是我尝试过的:

from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torch_dtype=dtype).eval()

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids:list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids in self.keywords:
            return True
        return False

stop_words = ['}', ' }', '\n']
stop_ids = [tokenizer.encode(w) for w in stop_words]
stop_ids.append(tokenizer.eos_token_id)
stop_criteria = KeywordsStoppingCriteria(stop_ids)

model.generate(
    text_inputs='some text:{', 
    StoppingCriteria=stop_criteria
)

4

0 回答 0