import torch
import torch.nn as nn
import torch.nn.functional as F

# Class header restored; the name is assumed, the original snippet omitted it.
class AttnLSTMClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_classes,
                 dropout=0.0):
        super().__init__()
        # The LSTM must be bidirectional: attention_net and fc both
        # expect features of size hidden_dim * 2.
        self.lstm = nn.LSTM(input_dim,
                            hidden_dim,
                            num_layers=num_layers,
                            bidirectional=True,
                            dropout=dropout,
                            batch_first=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def attention_net(self, lstm_output, final_state):
        # lstm_output: (batch, seq_len, hidden_dim * 2)
        # final_state: (batch, hidden_dim * 2)
        hidden = final_state.unsqueeze(2)
        # Dot-product score of every timestep against the final hidden state.
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)  # (batch, seq_len)
        soft_attn_weights = F.softmax(attn_weights, dim=1)
        # Attention-weighted sum of the LSTM outputs: (batch, hidden_dim * 2)
        context = torch.bmm(lstm_output.transpose(1, 2),
                            soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights.detach().cpu().numpy()

    def forward(self, text):
        # text: (batch, seq_len, input_dim)
        output, (hn, cn) = self.lstm(text)
        # Concatenate the last forward and last backward hidden states.
        hn = torch.cat((hn[-2, :, :], hn[-1, :, :]), dim=1)
        attn_output, attention = self.attention_net(output, hn)
        return self.fc(attn_output), attention
I am using an LSTM with attention. The model isn't learning with num_classes = 3; it keeps predicting only one class.
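
For reference, a minimal sketch of how the class above can be instantiated and called; the hyperparameters and shapes are illustrative assumptions, not from my actual data:

model = AttnLSTMClassifier(input_dim=100, hidden_dim=128,
                           num_layers=2, num_classes=3, dropout=0.5)
x = torch.randn(4, 25, 100)   # hypothetical batch: (batch=4, seq_len=25, input_dim=100)
logits, attention = model(x)
print(logits.shape)           # torch.Size([4, 3]) -- one logit per class
print(attention.shape)        # (4, 25) numpy array of attention weights

The attention step is plain dot-product attention: each timestep's LSTM output is scored against the concatenated final forward/backward hidden state, and the softmax-weighted sum of the outputs is what the linear classifier sees.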