I'm trying to optimize GPT-2 inference time. Currently, generating samples takes 55 seconds after invoking the script on Google Colab. I added timestamps to try to find where the bottleneck is. Here is the code:
for _ in range(nsamples // batch_size):
    out = sess.run(output, feed_dict={
        context: [context_tokens for _ in range(batch_size)]
    })[:, len(context_tokens):]
    for i in range(batch_size):
        generated += 1
        text = enc.decode(out[i])
        print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
        print(text)
        print("=" * 80)
The line

    out = sess.run(output, feed_dict={
        context: [context_tokens for _ in range(batch_size)]
    })[:, len(context_tokens):]

is where the time is spent. Does anyone have a way to improve this code? Thanks a lot!
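Since the question mentions adding timestamps by hand, here is a minimal sketch of a reusable timing helper that accumulates wall-clock time per labeled stage, so the `sess.run` call can be measured cleanly against the decode/print loop. The helper itself is pure Python (no TensorFlow dependency); the `sess.run` usage shown in the comment assumes the names from the question.

```python
import time
from contextlib import contextmanager

timings = {}

@contextmanager
def timed(label):
    """Accumulate wall-clock seconds spent inside the block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label] = timings.get(label, 0.0) + time.perf_counter() - start

# Usage sketch with the names from the question (sess, output, context, ...):
#
# with timed("sess.run"):
#     out = sess.run(output, feed_dict={
#         context: [context_tokens for _ in range(batch_size)]
#     })[:, len(context_tokens):]
# with timed("decode+print"):
#     ...
#
# Self-contained demo so the helper can be run without TensorFlow:
with timed("sleep"):
    time.sleep(0.05)

print(f"sleep: {timings['sleep']:.3f}s")
```

Comparing the accumulated totals per label across the whole loop (rather than single timestamps) makes it clear whether the model forward pass or the per-sample decoding dominates the 55 seconds.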