我尝试了 seq2seq 此处提供的 seq2seq pytorch实现。在分析评估(evaluate.py)代码后,需要较长时间的代码是decode_minibatch方法
def decode_minibatch(
config,
model,
input_lines_src,
input_lines_trg,
output_lines_trg_gold
):
"""Decode a minibatch."""
for i in xrange(config['data']['max_trg_length']):
decoder_logit = model(input_lines_src, input_lines_trg)
word_probs = model.decode(decoder_logit)
decoder_argmax = word_probs.data.cpu().numpy().argmax(axis=-1)
next_preds = Variable(
torch.from_numpy(decoder_argmax[:, -1])
).cuda()
input_lines_trg = torch.cat(
(input_lines_trg, next_preds.unsqueeze(1)),
1
)
return input_lines_trg
在 GPU 上训练模型并在 CPU 模式下加载模型进行推理。但不幸的是,每句话似乎都需要大约 10 秒。pytorch 预计会出现缓慢的预测吗?
任何修复,加快速度的建议将不胜感激。谢谢。