0

我尝试了 seq2seq 此处提供的 seq2seq pytorch实现。在分析评估(evaluate.py)代码后,需要较长时间的代码是decode_minibatch方法

def decode_minibatch(
    config,
    model,
    input_lines_src,
    input_lines_trg,
    output_lines_trg_gold
):
    """Decode a minibatch."""
    for i in xrange(config['data']['max_trg_length']):

        decoder_logit = model(input_lines_src, input_lines_trg)
        word_probs = model.decode(decoder_logit)
        decoder_argmax = word_probs.data.cpu().numpy().argmax(axis=-1)
        next_preds = Variable(
            torch.from_numpy(decoder_argmax[:, -1])
        ).cuda()

        input_lines_trg = torch.cat(
            (input_lines_trg, next_preds.unsqueeze(1)),
            1
        )

return input_lines_trg

在 GPU 上训练模型并在 CPU 模式下加载模型进行推理。但不幸的是,每句话似乎都需要大约 10 秒。pytorch 预计会出现缓慢的预测吗?

任何修复,加快速度的建议将不胜感激。谢谢。

4

0 回答 0