The network must have the following structure:
(lstm): LSTM(1, 64, batch_first=True)
(fc1): Linear(in_features=64, out_features=32, bias=True)
(relu): ReLU()
(fc2): Linear(in_features=32, out_features=5, bias=True)
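The structure above implies specific tensor shapes. With `input_size=1` and `batch_first=True`, the LSTM expects input shaped (batch, seq_len, 1) and returns one 64-dim output per time step, so something must reduce the sequence before the linear layers if the final output is to be (batch, 5). The sketch below traces the shapes layer by layer. The batch size and sequence length are hypothetical, and taking the last time step is one common reduction, not the only one.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only
batch_size, seq_len = 4, 10

lstm = nn.LSTM(1, 64, batch_first=True)
fc1 = nn.Linear(in_features=64, out_features=32, bias=True)
relu = nn.ReLU()
fc2 = nn.Linear(in_features=32, out_features=5, bias=True)

# batch_first=True with input_size=1 means input is (batch, seq_len, 1)
x = torch.randn(batch_size, seq_len, 1)
out, (h_n, c_n) = lstm(x)
print(out.shape)  # torch.Size([4, 10, 64]): one 64-dim output per time step
print(h_n.shape)  # torch.Size([1, 4, 64]): (num_layers, batch, hidden), unaffected by batch_first

last = out[:, -1, :]            # keep only the last time step -> (4, 64)
logits = fc2(relu(fc1(last)))   # (4, 32) -> (4, 5)
print(logits.shape)  # torch.Size([4, 5])
```

Without the `out[:, -1, :]` step, `fc1` and `fc2` would be applied to every time step and the result would be (batch, seq_len, 5).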
I wrote this code:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):
    def __init__(self):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(1, 64, batch_first=True)
        self.fc1 = nn.Linear(in_features=64, out_features=32, bias=True)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=32, out_features=5, bias=True)

    def forward(self, x):
        x = torch.tanh(self.lstm(x)[0])  # LSTM output for every time step
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
This is my test:
(batch_data, batch_label) = next(iter(train_loader))
model = LSTMClassifier().to(device)
output = model(batch_data.to(device)).cpu()
assert output.shape == (batch_size, 5)
print("passed")
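The definition of `train_loader` isn't shown, but the error below is consistent with each sample being stored as a 1-D sequence. In that case the default collate function stacks samples into a 2-D batch (batch, seq_len) with no trailing feature dimension. A minimal sketch, assuming a hypothetical `TensorDataset` of 100 scalar sequences of length 10:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 100 sequences of length 10, one scalar per time step
num_samples, seq_len, batch_size = 100, 10, 16
features = torch.randn(num_samples, seq_len)  # 2-D: no feature dimension
labels = torch.randint(0, 5, (num_samples,))

train_loader = DataLoader(TensorDataset(features, labels), batch_size=batch_size)
batch_data, batch_label = next(iter(train_loader))
print(batch_data.dim(), tuple(batch_data.shape))  # 2 (16, 10)

# Adding a trailing feature dimension gives the 3-D shape nn.LSTM expects
print(tuple(batch_data.unsqueeze(-1).shape))  # (16, 10, 1)
```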
The error is:
----> 3 output = model(batch_data.to(device)).cpu()
5 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py in check_input(self, input, batch_sizes)
    201             raise RuntimeError(
    202                 'input must have {} dimensions, got {}'.format(
--> 203                     expected_input_dim, input.dim()))
    204         if self.input_size != input.size(-1):
    205             raise RuntimeError(

RuntimeError: input must have 3 dimensions, got 2
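The message means `nn.LSTM` received a 2-D tensor where, with `batch_first=True`, it needs (batch, seq_len, input_size). One possible fix, sketched below under the assumption that `batch_data` is (batch, seq_len) of scalar values, is to add the feature dimension inside `forward`, keep only the last time step, and return the result. Dropping the extra `torch.tanh` and using `self.relu` instead of `F.relu` are choices made here, not requirements; the LSTM output is already bounded by its internal tanh.

```python
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(1, 64, batch_first=True)
        self.fc1 = nn.Linear(in_features=64, out_features=32, bias=True)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=32, out_features=5, bias=True)

    def forward(self, x):
        # Assumption: x is (batch, seq_len); add the feature dim -> (batch, seq_len, 1)
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        out, _ = self.lstm(x)       # (batch, seq_len, 64)
        x = out[:, -1, :]           # last time step -> (batch, 64)
        x = self.relu(self.fc1(x))  # (batch, 32)
        return self.fc2(x)          # (batch, 5) raw class scores

model = LSTMClassifier()
print(model)  # matches the required structure listed at the top
batch_data = torch.randn(8, 10)  # hypothetical 2-D batch
output = model(batch_data)
print(output.shape)  # torch.Size([8, 5])
```

If the loader yields `float64` tensors (for example, data converted from NumPy), the LSTM will also need `batch_data.float()`, since its weights default to `float32`.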
What is wrong with my code?