0

我正在阅读有关使用 AllenNlp 框架的基于字符的神经网络的教程,目标是构建一个可以完成句子的模型。在我想训练我的模型之后,有一个实例构建步骤。我有下面的代码,我无法理解转发功能的作用,任何人都可以帮忙吗?有人可以举个例子吗

class RNNLanguageModel(Model):
def __init__(self,
             embedder: TextFieldEmbedder,
             hidden_size: int,
             max_len: int,
             vocab: Vocabulary) -> None:
    super().__init__(vocab)

    self.embedder = embedder

    # initialize a Seq2Seq encoder, LSTM
    self.rnn = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(EMBEDDING_SIZE, HIDDEN_SIZE, batch_first=True))

    self.hidden2out = torch.nn.Linear(in_features=self.rnn.get_output_dim(), out_features=vocab.get_vocab_size('tokens'))
    self.hidden_size = hidden_size
    self.max_len = max_len

def forward(self, input_tokens, output_tokens):
    '''
    This is the main process of the Model where the actual computation happens. 
    Each Instance is fed to the forward method. 
    It takes dicts of tensors as input, with same keys as the fields in your Instance (input_tokens, output_tokens)
    It outputs the results of predicted tokens and the evaluation metrics as a dictionary. 
    '''

    mask = get_text_field_mask(input_tokens)
    embeddings = self.embedder(input_tokens)
    rnn_hidden = self.rnn(embeddings, mask)
    out_logits = self.hidden2out(rnn_hidden)
    loss = sequence_cross_entropy_with_logits(out_logits, output_tokens['tokens'], mask)

    return {'loss': loss}
4

1 回答 1

0

forward()方法是我们实现模型“前向传递”的地方。这决定了输入(您的数据)如何流经您的模型以产生输出和损失值。

forward()方法需要由继承自 PyTorch 的任何类实现Module,例如 AllenNLP 的Model类。

AllenNLP 最终只是 PyTorch 的一个更高级别的包装器,所以如果您对此感到困惑,我建议您首先熟悉 PyTorch:https ://pytorch.org/tutorials/

于 2020-10-14T16:21:13.533 回答