import coremltools as ct
import numpy as np
import torch
from transformers import GPT2Tokenizer, GPTNeoForCausalLM
# Initial prompt text for the script.
sentence_fragment = "The Oceans are"
# GPT-2 tokenizer (loaded from the "gpt2" checkpoint); the script uses it to
# encode prompts into token ids and to decode generated ids back into text.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")
class NEO(torch.nn.Module):
    """Append one predicted token id to a 1-D token sequence per call.

    Wraps a next-token model (in this script a traced GPT-Neo) so the wrapper
    can be compiled with ``torch.jit.script``.

    By default the next token is chosen greedily with ``argmax``. That is
    deterministic: the same prompt always yields the same continuation.
    With ``do_sample=True`` the token is instead drawn from the softmax of the
    scores, so repeated calls produce different continuations. This is similar
    in spirit to ``model.generate(..., do_sample=True)``. The scores can be
    rescaled by ``temperature`` and restricted to the ``top_k`` best tokens.

    Args:
        model: callable that takes the 1-D token tensor and returns a tuple
            whose first element holds per-position scores.
            ``predictions[-1, :]`` is read as the scores for the next token,
            so this assumes shape (seq_len, vocab_size). TODO confirm for
            models other than the traced GPT-Neo used here.
        do_sample: if True, sample the next token instead of taking argmax.
        temperature: positive divisor applied to the scores before softmax
            when sampling. Values below 1 make sampling greedier; values
            above 1 make it more random.
        top_k: if > 0, only the ``top_k`` highest-scoring tokens can be
            sampled. 0 disables this filter.

    Raises:
        ValueError: if ``temperature`` is not positive or ``top_k`` is negative.
    """

    def __init__(
        self,
        model,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_k: int = 0,
    ):
        super().__init__()
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        if top_k < 0:
            raise ValueError("top_k must be >= 0")
        self.next_token_predictor = model
        # Plain Python scalars; TorchScript infers them as typed attributes.
        self.do_sample = bool(do_sample)
        self.temperature = float(temperature)
        self.top_k = int(top_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with one next-token id concatenated along dim 0."""
        predictions, _ = self.next_token_predictor(x)
        # Scores for the token that follows the last position of x.
        logits = predictions[-1, :]
        if self.do_sample:
            logits = logits / self.temperature
            if self.top_k > 0:
                k = min(self.top_k, logits.size(0))
                kth_best = torch.topk(logits, k)[0][-1]
                # Push everything below the k-th best score to -inf so
                # softmax assigns it zero probability.
                logits = logits.masked_fill(logits < kth_best, float("-inf"))
            probs = torch.softmax(logits, dim=0)
            # One random draw per call: this is what makes outputs vary.
            token = torch.multinomial(probs, num_samples=1)
        else:
            token = torch.argmax(logits, dim=0, keepdim=True)
        return torch.cat((x, token), 0)
# Load GPT-Neo 125M with torchscript=True and switch it to eval mode.
token_predictor = GPTNeoForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125M", torchscript=True
).eval()
# torch.jit.trace needs an example input: five random token ids in
# [0, 10000). The loop further down feeds prompts of other lengths.
random_tokens = torch.randint(10000, (5,))
traced_token_predictor = torch.jit.trace(token_predictor, random_tokens)
# Wrap the traced predictor so one call appends one token, then compile the
# wrapper with TorchScript.
model = NEO(model=traced_token_predictor)
scripted_model = torch.jit.script(model)
# Custom model: extend the prompt by one token per call of scripted_model.
sentence_fragment = "The Oceans are"
# Encode once and keep working on token ids. Decoding to text and re-encoding
# on every step re-tokenizes the whole sentence, and encode(decode(ids)) can
# split text into different ids than the ones the model produced.
context = torch.tensor(tokenizer.encode(sentence_fragment))
for _ in range(10):
    context = scripted_model(context)
sentence_fragment = tokenizer.decode(context)
print("Custom model: {}".format(sentence_fragment))
# Stock model: Hugging Face generate() with sampling, for comparison.
model = GPTNeoForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125M", torchscript=True
).eval()
sentence_fragment = "The Oceans are"
encoded = tokenizer(sentence_fragment, return_tensors="pt")
input_ids = encoded.input_ids
# Pass the attention mask and pad id explicitly instead of letting generate()
# infer them. NOTE(review): transformers warns and falls back to eos_token_id
# when these are missing; confirm for the installed version.
gen_tokens = model.generate(
    input_ids,
    attention_mask=encoded.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=True,
    max_length=20,
)
# Drop special tokens (e.g. an EOS/pad id if sampling stops before
# max_length) so they are not printed as literal text.
gen_text = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]
print("Stock model: " + gen_text)
运行 1
输出:
Custom model: The Oceans are the most important source of water for the entire world
Stock model: The Oceans are on the rise. The American Southwest is thriving, but the southern United States still
运行 2
输出:
Custom model: The Oceans are the most important source of water for the entire world.
Stock model: The Oceans are the land of man
This is a short video of the Australian government
自定义模型（在 forward 中用 torch.argmax 选取下一个 token，即贪心解码）始终返回相同的输出。然而，设置 do_sample=True 后，原生（stock）模型的 model.generate 在每次调用时都会返回不同的结果。我花了很多时间研究 transformers 库中 do_sample 的工作原理，所以需要大家的帮助，不胜感激。
如何编写自定义模型，使其每次调用时都能产生不同的结果？
谢谢！