I am trying to add an attention mechanism to the stacked LSTM implementation at https://github.com/salesforce/awd-lstm-lm.
All the examples I can find online use an encoder-decoder architecture, which I do not want to use (do I have to, just to use attention?).
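For context, what I have in mind is plain dot-product attention over the LSTM's own outputs, with no separate decoder. Here is a minimal sketch of that idea (the batch-first layout and the choice of the last hidden state as the query are my own assumptions, not anything from the repo):

import torch

def dot_attention(rnn_out, query):
    # rnn_out: (batch, seq_len, nhid); query: (batch, nhid), e.g. the last hidden state
    scores = torch.bmm(rnn_out, query.unsqueeze(2))        # (batch, seq_len, 1)
    weights = torch.softmax(scores, dim=1)                 # normalise over timesteps
    context = torch.bmm(rnn_out.transpose(1, 2), weights)  # (batch, nhid, 1)
    return context.squeeze(2)                              # (batch, nhid)

rnn = torch.nn.LSTM(input_size=10, hidden_size=16, batch_first=True)
out, (h, c) = rnn(torch.randn(4, 7, 10))                   # input: (batch, seq_len, input_size)
print(dot_attention(out, h[-1]).shape)                     # torch.Size([4, 16])

My actual attempt, grafted onto the awd-lstm-lm model, is below: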
import torch
import torch.nn as nn
from embed_regularize import embedded_dropout   # helpers from the awd-lstm-lm repo
from locked_dropout import LockedDropout
from weight_drop import WeightDrop

class RNNModel(nn.Module):
    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, dropouth=0.5, dropouti=0.5, dropoute=0.1, wdrop=0, tie_weights=False):
        super(RNNModel, self).__init__()
        self.lockdrop = LockedDropout()
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnns = [torch.nn.LSTM(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), 1, dropout=0) for l in range(nlayers)]
        # nn.LSTM has no .linear attribute; as in the repo, DropConnect goes on the recurrent weights
        self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns]
        self.rnns = torch.nn.ModuleList(self.rnns)
        self.attn_fc = torch.nn.Linear(ninp, 1)
        self.decoder = nn.Linear(nhid, ntoken)
        # keep the hyper-parameters that forward() reads
        self.nlayers, self.dropout, self.dropouti, self.dropouth, self.dropoute = nlayers, dropout, dropouti, dropouth, dropoute
        self.init_weights()
    def attention(self, rnn_out, state):
        # rnn_out and state are each (1, batch, nhid) here, since forward() feeds one timestep at a time
        state = torch.transpose(state, 1, 2)
        # dot-product scores between the outputs and the hidden state
        weights = torch.bmm(rnn_out, state)
        # dim must be given explicitly in current PyTorch
        weights = torch.nn.functional.softmax(weights, dim=1)
        rnn_out_t = torch.transpose(rnn_out, 1, 2)
        # weighted sum of the outputs
        bmmed = torch.bmm(rnn_out_t, weights)
        bmmed = bmmed.squeeze(2)
        return bmmed
    def forward(self, input, hidden, return_h=False, decoder=False, encoder_outputs=None):
        emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
        emb = self.lockdrop(emb, self.dropouti)
        new_hidden = []
        raw_outputs = []
        outputs = []
        for l, rnn in enumerate(self.rnns):
            temp = []
            # feed the layer one timestep at a time so attention sees each step's hidden state
            for item in emb:
                item = item.unsqueeze(0)
                raw_output, new_h = rnn(item, hidden[l])
                raw_output = self.attention(raw_output, new_h[0])
                temp.append(raw_output)
            raw_output = torch.stack(temp)
            raw_output = raw_output.squeeze(1)
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            if l != self.nlayers - 1:
                raw_output = self.lockdrop(raw_output, self.dropouth)
                outputs.append(raw_output)
        hidden = new_hidden
        output = self.lockdrop(raw_output, self.dropout)
        outputs.append(output)
        outputs = torch.stack(outputs).squeeze(0)
        outputs = torch.transpose(outputs, 2, 1)
        output = output.transpose(2, 1)
        output = output.contiguous()
        # project to the vocabulary
        decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
        result = decoded.view(output.size(0), output.size(1), decoded.size(1))
        if return_h:
            return result, hidden, raw_outputs, outputs
        return result, hidden
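For reference, this is how I drive it (sizes are placeholders I picked; the hidden-state initialisation mirrors what the repo's init_hidden produces for tie_weights=False, and it assumes the repo's helper modules are on the path):

ntoken, ninp, nhid, nlayers, batch, bptt = 1000, 40, 40, 2, 4, 5
model = RNNModel('LSTM', ntoken, ninp, nhid, nlayers, wdrop=0.1)
# one (h, c) pair per layer, each tensor shaped (1, batch, nhid)
hidden = [(torch.zeros(1, batch, nhid), torch.zeros(1, batch, nhid)) for _ in range(nlayers)]
data = torch.randint(0, ntoken, (bptt, batch))                # (seq_len, batch) token ids
result, hidden = model(data, hidden)
print(result.shape)                                           # (bptt, batch, ntoken)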
The model trains, but the loss is considerably higher than with the same model without attention.
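In case it helps diagnosis, here is the shape trace I get through attention() for a single timestep (dummy tensors, batch of 4, nhid of 40):

rnn_out = torch.randn(1, 4, 40)                        # one step: (1, batch, nhid)
state = torch.randn(1, 4, 40)                          # h for that step: (1, batch, nhid)
scores = torch.bmm(rnn_out, state.transpose(1, 2))     # (1, 4, 4) -- batch x batch, not batch x seq
weights = torch.softmax(scores, dim=1)
context = torch.bmm(rnn_out.transpose(1, 2), weights)  # (1, 40, 4)
print(context.squeeze(2).shape)                        # torch.Size([1, 40, 4]); squeeze(2) is a no-op here

So my attention weights appear to mix batch elements rather than timesteps, which may be related to the high loss.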