Goal: optimize and better understand the BiLSTM
I have a working BiLSTM. However, the val_score for the first epoch is 0.0.
I believe this straightforward issue is caused by my incorrect handling of training for this additional layer.
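For reference, here is a minimal sketch (my own illustration, not the code below, but using the same dimensions: embedding_dim=100, hidden_dim=50, batch_size=10, and 38 classes as in the log output) of how the final forward and backward hidden states of a bidirectional nn.LSTM are typically concatenated and fed into a classifier:

    import torch
    import torch.nn as nn

    # Sketch only: same dimensions as the model below (seq_len = 7 is an arbitrary choice).
    lstm = nn.LSTM(input_size=100, hidden_size=50, num_layers=1,
                   batch_first=True, bidirectional=True)
    classifier = nn.Linear(2 * 50, 38)            # 2 * hidden_dim because both directions are concatenated

    x = torch.randn(10, 7, 100)                   # (batch, seq_len, embedding_dim)
    output, (h_n, c_n) = lstm(x)                  # h_n: (num_layers * num_directions, batch, hidden_dim) = (2, 10, 50)

    h_n = h_n.view(1, 2, 10, 50)                  # (num_layers, num_directions, batch, hidden_dim)
    h_fwd, h_bwd = h_n[-1][0], h_n[-1][1]         # last layer, forward and backward direction
    features = torch.cat((h_fwd, h_bwd), dim=1)   # (batch, 2 * hidden_dim) = (10, 100)
    logits = classifier(features)                 # (batch, num_classes) = (10, 38)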
Questions:
- What could cause a BiLSTM to behave like this?
- What have I not implemented, or implemented incorrectly?
Code:
from argparse import ArgumentParser
import torchmetrics
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTMClassifier(nn.Module):

    def __init__(self,
                 num_classes,
                 batch_size=10,
                 embedding_dim=100,
                 hidden_dim=50,
                 vocab_size=128):
        super(LSTMClassifier, self).__init__()

        initrange = 0.1

        self.num_labels = num_classes
        n = len(self.num_labels)
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = 1

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=self.num_layers, batch_first=True, bidirectional=True)  # !
        # self.classifier = nn.Linear(hidden_dim, self.num_labels[0])
        self.classifier = nn.Linear(2 * hidden_dim, self.num_labels[0])  # !

    def repackage_hidden(h):
        """Wraps hidden states in new Tensors, to detach them from their history."""
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(repackage_hidden(v) for v in h)

    def forward(self, sentence, labels=None):
        embeds = self.word_embeddings(sentence)

        # lstm_out, _ = self.lstm(embeds)  # lstm_out - 2 tensors, _ - hidden layer
        lstm_out, hidden = self.lstm(embeds)

        # Calculate number of directions
        self.num_directions = 2 if self.lstm.bidirectional == True else 1

        # Extract last hidden state
        # final_state = hidden.view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]
        final_state = hidden[0].view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]

        # Handle directions
        final_hidden_state = None
        if self.num_directions == 1:
            final_hidden_state = final_state.squeeze(0)
        elif self.num_directions == 2:
            h_1, h_2 = final_state[0], final_state[1]
            # final_hidden_state = h_1 + h_2  # Add both states (requires changes to the input size of first linear layer + attention layer)
            final_hidden_state = torch.cat((h_1, h_2), 1)  # Concatenate both states

        self.linear_dims = [0]

        # Define set of fully connected layers (Linear Layer + Activation Layer) * #layers
        self.linears = nn.ModuleList()
        for i in range(0, len(self.linear_dims)-1):
            linear_layer = nn.Linear(self.linear_dims[i], self.linear_dims[i+1])
            self.init_weights(linear_layer)
            self.linears.append(linear_layer)
            if i == len(self.linear_dims) - 1:
                break  # no activation after output layer!!!
            self.linears.append(nn.ReLU())

        X = final_hidden_state

        # Push through linear layers
        for l in self.linears:
            X = l(X)

        print('type(X)', type(X))
        print('len(X)', len(X))
        print('X.shape', X.shape)

        print('type(labels)', type(labels))
        print('len(labels)', len(labels))
        print('labels', labels)

        print('type(hidden[0])', type(hidden[0]))  # hidden[0] - tensor
        print('len(hidden[0])', len(hidden[0]))
        print('hidden[0].shape', hidden[0].shape)
        print('hidden[0]', labels)

        logits = self.classifier(X)  # !  # torch.flip(lstm_out[:,-1,:], [0, 1]) - 1 tensor

        print('type(logits)', type(logits))
        print('len(logits)', len(logits))
        print('logits.shape', logits.shape)

        loss = None
        if labels:
            print("len(self.num_labels)", len(self.num_labels))
            print("self.num_labels[0]", self.num_labels[0])
            print("len(labels[0].view(-1))", len(labels[0].view(-1)))
            loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))

        return loss, logits
class LSTMTaggerModel(pl.LightningModule):

    def __init__(
        self,
        num_classes,
        class_map,
        from_checkpoint=False,
        model_name='last.ckpt',
        learning_rate=3e-6,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.model = LSTMClassifier(num_classes=num_classes)
        # self.model.load_state_dict(torch.load(model_name), strict=False)  # !
        self.class_map = class_map
        self.num_classes = num_classes
        self.valid_acc = torchmetrics.Accuracy()
        self.valid_f1 = torchmetrics.F1()

    def forward(self, *input, **kwargs):
        return self.model(*input, **kwargs)

    def training_step(self, batch, batch_idx):
        x, y_true = batch
        loss, _ = self(x, labels=y_true)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_true = batch
        _, y_pred = self(x, labels=y_true)
        preds = torch.argmax(y_pred, axis=1)
        self.valid_acc(preds, y_true[0])
        self.log('val_acc', self.valid_acc, prog_bar=True)
        self.valid_f1(preds, y_true[0])
        self.log('f1', self.valid_f1, prog_bar=True)

    def configure_optimizers(self):
        'Prepare optimizer and schedule (linear warmup and decay)'
        opt = torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def training_epoch_end(self, training_step_outputs):
        avg_loss = torch.tensor([x['loss']
                                 for x in training_step_outputs]).mean()
        self.log('train_loss', avg_loss)
        print(f'###score: train_loss### {avg_loss}')

    def validation_epoch_end(self, val_step_outputs):
        acc = self.valid_acc.compute()
        f1 = self.valid_f1.compute()
        self.log('val_score', acc)
        self.log('f1', f1)
        print(f'###score: val_score### {acc}')

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("OntologyTaggerModel")
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", default=2e-3, type=float)
        return parent_parser
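For context, the shape printouts in the traceback below are consistent with calling the classifier roughly as follows (a hypothetical reproduction; the sequence length of 20 is my own choice, everything else mirrors the defaults and log values above):

    model = LSTMClassifier(num_classes=[38])
    sentence = torch.randint(0, 128, (10, 20))      # (batch_size, seq_len) token ids, vocab_size=128
    labels = [torch.randint(0, 38, (10,))]          # list holding one (batch_size,) label tensor, as in the log
    loss, logits = model(sentence, labels=labels)   # logits.shape == (10, 38), X.shape == (10, 100)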
Traceback (the print statements repeat):
Global seed set to 42
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/deprecate/deprecation.py:115: FutureWarning: The `F1` was deprecated since v0.7 in favor of `torchmetrics.classification.f_beta.F1Score`. It will be removed in v0.8.
stream(template_mgs % msg_args)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration. You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).
warnings.warn(*args, **kwargs)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
| Name | Type | Params
---------------------------------------------
0 | model | LSTMClassifier | 77.4 K
1 | valid_acc | Accuracy | 0
2 | valid_f1 | F1 | 0
---------------------------------------------
77.4 K Trainable params
0 Non-trainable params
77.4 K Total params
0.310 Total estimated model params size (MB)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.
warnings.warn(*args, **kwargs)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
Validation sanity check: 0it [00:00, ?it/s]
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric Accuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
warnings.warn(*args, **kwargs)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric F1 was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
warnings.warn(*args, **kwargs)
type(X) <class 'torch.Tensor'>
len(X) 10
X.shape torch.Size([10, 100])
type(labels) <class 'list'>
len(labels) 1
labels [tensor([ 2, 31, 26, 37, 22, 5, 31, 36, 5, 10])]
type(hidden[0]) <class 'torch.Tensor'>
len(hidden[0]) 2
hidden[0].shape torch.Size([2, 10, 50])
hidden[0] [tensor([ 2, 31, 26, 37, 22, 5, 31, 36, 5, 10])]
type(logits) <class 'torch.Tensor'>
len(logits) 10
logits.shape torch.Size([10, 38])
len(self.num_labels) 1
self.num_labels[0] 38
len(labels[0].view(-1)) 10
...