我使用库训练了一个 T5 变压器simpletransformers
。
这是获取预测的代码:
pred_values = model.predict(input_values)
但是,它只返回顶部或贪婪预测,我怎样才能获得 10 个顶部结果?
我使用库训练了一个 T5 变压器simpletransformers
。
这是获取预测的代码:
pred_values = model.predict(input_values)
但是,它只返回顶部或贪婪预测,我怎样才能获得 10 个顶部结果?
必需的参数是num_return_sequences
,它显示要生成的样本数。但是,如果要使用波束搜索算法,还应该为波束搜索设置一个数字。
model_args = T5Args()
model_args.num_beams = 5
model_args.num_return_sequences = 2
或者,您可以使用top_k
或top_p
在顶级样本中生成和选择,在这些情况下,您必须设置do_sample
为True
。有关参数的更多信息,请参阅[1]和[2],这是一个详细的解释。
model_args = T5Args()
model_args.do_sample = True
model_args.top_p = 0.9
model_args.num_return_sequences = 2