import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers,
                 bidirectional, dropout, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_dim,
                            hidden_dim,
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            dropout=dropout,
                            batch_first=True)
        # hidden_dim * 2 because the forward and backward final hidden
        # states are concatenated below (assumes bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def attention_net(self, lstm_output, final_state):
        # lstm_output: (batch, seq_len, hidden_dim * 2)
        # final_state: (batch, hidden_dim * 2)
        hidden = final_state.unsqueeze(2)
        # dot-product score of every time step against the final state
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)
        soft_attn_weights = F.softmax(attn_weights, dim=1)
        # attention-weighted sum of the LSTM outputs over time
        context = torch.bmm(lstm_output.transpose(1, 2),
                            soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights.detach().cpu().numpy()

    def forward(self, text):
        # text: already-embedded input of shape (batch, seq_len, input_dim)
        output, (hn, cn) = self.lstm(text)
        # concatenate the last layer's forward and backward hidden states
        hn = torch.cat((hn[-2, :, :], hn[-1, :, :]), dim=1)
        attn_output, attention = self.attention_net(output, hn)
        return self.fc(attn_output), attention
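
For reference, a minimal sketch of how the model is meant to be called (LSTMAttention is just the name I gave the class above, and every hyperparameter value here is a placeholder, not my real configuration):

    # hypothetical smoke test with random data
    model = LSTMAttention(input_dim=100, hidden_dim=128, num_layers=2,
                          bidirectional=True, dropout=0.3, num_classes=3)
    dummy = torch.randn(8, 50, 100)   # (batch, seq_len, input_dim)
    logits, attention = model(dummy)
    print(logits.shape)               # torch.Size([8, 3])
    print(attention.shape)            # (8, 50), returned as a NumPy array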

I'm using an LSTM with attention. With num_classes = 3 the model doesn't learn: it keeps predicting the same single class for every input (see the check sketched below).
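
A minimal sketch of how I confirm the collapse, assuming a standard validation DataLoader; val_loader, device, and the (inputs, labels) batch format are placeholders for my real setup:

    # hypothetical diagnostic: count how often each class is predicted
    from collections import Counter

    counts = Counter()
    model.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            logits, _ = model(inputs.to(device))
            counts.update(logits.argmax(dim=1).tolist())
    print(counts)  # e.g. Counter({0: 512}) means every prediction is class 0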
