0

我使用库训练了一个 T5 变压器simpletransformers

这是获取预测的代码:

pred_values = model.predict(input_values)

但是,它只返回顶部或贪婪预测,我怎样才能获得 10 个顶部结果?

4

1 回答 1

3

必需的参数是num_return_sequences,它显示要生成的样本数。但是,如果要使用波束搜索算法,还应该为波束搜索设置一个数字。

model_args = T5Args()
model_args.num_beams = 5
model_args.num_return_sequences = 2

或者,您可以使用top_ktop_p在顶级样本中生成和选择,在这些情况下,您必须设置do_sampleTrue。有关参数的更多信息,请参阅[1]和[2],这是一个详细的解释。

model_args = T5Args()
model_args.do_sample = True
model_args.top_p = 0.9
model_args.num_return_sequences = 2

[1] https://simpletransformers.ai/docs/t5-model/

[2] https://huggingface.co/blog/how-to-generate

于 2021-03-12T17:46:55.013 回答