我有一个带有 LSTM 的类,然后是 nn.ModuleList,我在其中定义了另外两个 LSTM 层,但是循环的前向函数失败并出现错误“forward() 需要 1 个位置参数,但给出了 3 个”。
分享了下面的代码和错误。当我尝试传递较早的层输出、单元状态和隐藏状态值但 nn.ModuleList 不允许时,它似乎失败了。谁能帮我解决这个问题?
代码:
class RNN(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers,
bidirectional, dropout, pad_idx,nadd_layers=2):
super().__init__()
self.addl_layer_list = [nn.ModuleList([nn.LSTM(hidden_dim*2,hidden_dim,num_layers=1,bidirectional=True,dropout=dropout), nn.LSTM(hidden_dim*2,hidden_dim,num_layers=1,bidirectional=True,dropout=dropout)])]
self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)
self.nlayers = nadd_layers
# self.layers = nn.ModuleList([nn.Linear(in_f, out_f) for in_f, out_f in zip(sizes, sizes[1:])])
# for layer in range(self.nlayers):
# layer
self.rnn1 = nn.LSTM(embedding_dim,
hidden_dim,
num_layers=1,
bidirectional=True,
dropout=dropout)
self.addl_layers = nn.ModuleList(self.addl_layer_list)
self.fc = nn.Linear(hidden_dim, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, text, text_lengths):
#text = [sent len, batch size] https://stackoverflow.com/questions/49224413/difference-between-1-lstm-with-num-layers-2-and-2-lstms-in-pytorch
embedded = self.dropout(self.embedding(text))
#pack sequence
packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths)
packed_output, (hidden, cell) = self.rnn1(packed_embedded)
hidden_final = hidden
hidden_set = True
for layer in range(self.nlayers):
packed_output1, (hidden1, cell1) = self.addl_layers[layer](packed_output,(hidden, cell))
错误:
TypeError Traceback (most recent call last)
<ipython-input-164-4185e35f9156> in <module>()
7 start_time = time.time()
8
----> 9 train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
10 valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
11
3 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
TypeError: forward() takes 1 positional argument but 3 were given