
Goal: optimize and better understand a BiLSTM.

I have a working BiLSTM. However, the val_score for the first epoch is 0.0.

I suspect this rather basic problem comes from how I handle the additional (bidirectional) layer during training.

Questions:

  • What could cause a BiLSTM to produce a result like this?
  • What have I not implemented, or implemented incorrectly?
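
For reference, this is the pattern I am comparing against: with a bidirectional nn.LSTM, h_n has shape (num_layers * num_directions, batch, hidden_dim), and the last layer's forward and backward states are concatenated before the classifier. The snippet below is a minimal standalone sketch using the same default dimensions as my model (batch size 10 and sequence length 25 are arbitrary; num_classes=38 matches the log further down):

import torch
import torch.nn as nn

# Minimal sketch of a standard BiLSTM classifier head (standalone, simplified from the model below).
embedding_dim, hidden_dim, num_classes, vocab_size = 100, 50, 38, 128
embedding = nn.Embedding(vocab_size, embedding_dim)
lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
classifier = nn.Linear(2 * hidden_dim, num_classes)

tokens = torch.randint(0, vocab_size, (10, 25))        # (batch, seq_len), random token ids
_, (h_n, _) = lstm(embedding(tokens))                  # h_n: (num_layers * 2, batch, hidden_dim)
h_n = h_n.view(1, 2, tokens.size(0), hidden_dim)[-1]   # keep the last layer -> (2, batch, hidden_dim)
features = torch.cat((h_n[0], h_n[1]), dim=1)          # concat directions -> (batch, 2 * hidden_dim)
logits = classifier(features)                          # (batch, num_classes)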

Code:

from argparse import ArgumentParser

import torchmetrics
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMClassifier(nn.Module):

    def __init__(self, 
        num_classes, 
        batch_size=10,
        embedding_dim=100, 
        hidden_dim=50, 
        vocab_size=128):

        super(LSTMClassifier, self).__init__()

        initrange = 0.1

        self.num_labels = num_classes
        n = len(self.num_labels)
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        
        self.num_layers = 1

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=self.num_layers, batch_first=True, bidirectional=True)  # !
        #self.classifier = nn.Linear(hidden_dim, self.num_labels[0])
        self.classifier = nn.Linear(2 * hidden_dim, self.num_labels[0])  # !


    def repackage_hidden(h):
        """Wraps hidden states in new Tensors, to detach them from their history."""

        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(repackage_hidden(v) for v in h)


    def forward(self, sentence, labels=None):
        embeds = self.word_embeddings(sentence)
        # lstm_out, _ = self.lstm(embeds)  # lstm_out - 2 tensors, _ - hidden layer
        lstm_out, hidden = self.lstm(embeds)
        
        # Calculate number of directions
        self.num_directions = 2 if self.lstm.bidirectional == True else 1
        
        # Extract last hidden state
        # final_state = hidden.view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]
        final_state = hidden[0].view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]
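        # hidden[0] is h_n with shape (num_layers * num_directions, batch, hidden_dim) = (2, 10, 50) here;
        # the view separates layers from directions, and [-1] keeps the last layer's two directional states.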
        # Handle directions
        final_hidden_state = None
        if self.num_directions == 1:
            final_hidden_state = final_state.squeeze(0)
        elif self.num_directions == 2:
            h_1, h_2 = final_state[0], final_state[1]
            # final_hidden_state = h_1 + h_2               # Add both states (requires changes to the input size of first linear layer + attention layer)
            final_hidden_state = torch.cat((h_1, h_2), 1)  # Concatenate both states
        
        self.linear_dims = [0]
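        # Note: with linear_dims == [0], len(self.linear_dims) - 1 == 0, so the loop below never
        # executes; self.linears stays empty and X reaches self.classifier unchanged.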
        
        # Define set of fully connected layers (Linear Layer + Activation Layer) * #layers
        self.linears = nn.ModuleList()
        for i in range(0, len(self.linear_dims)-1):
            linear_layer = nn.Linear(self.linear_dims[i], self.linear_dims[i+1])
            self.init_weights(linear_layer)
            self.linears.append(linear_layer)
            if i == len(self.linear_dims) - 1:
                break  # no activation after output layer!!!
            self.linears.append(nn.ReLU())
        
        X = final_hidden_state
        
        # Push through linear layers
        for l in self.linears:
            X = l(X)

        print('type(X)', type(X))
        print('len(X)', len(X))
        print('X.shape', X.shape)
        
        print('type(labels)', type(labels))
        print('len(labels)', len(labels))
        print('labels', labels)
        
        print('type(hidden[0])', type(hidden[0]))  # hidden[0] - tensor
        print('len(hidden[0])', len(hidden[0]))
        print('hidden[0].shape', hidden[0].shape)
        print('hidden[0]', labels)

        
        logits = self.classifier(X)  # !  # torch.flip(lstm_out[:,-1,:], [0, 1]) - 1 tensor
        print('type(logits)', type(logits))
        print('len(logits)', len(logits))
        print('logits.shape', logits.shape)
        
        loss = None
        if labels:
            print("len(self.num_labels)", len(self.num_labels))
            print("self.num_labels[0]", self.num_labels[0])
            print("len(labels[0].view(-1))", len(labels[0].view(-1)))
            loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
            
        return loss, logits


class LSTMTaggerModel(pl.LightningModule):
    def __init__(
        self,
        num_classes,
        class_map,
        from_checkpoint=False,
        model_name='last.ckpt',
        learning_rate=3e-6,
        **kwargs,
    ):

        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.model = LSTMClassifier(num_classes=num_classes)
        # self.model.load_state_dict(torch.load(model_name), strict=False)  # !
        self.class_map = class_map
        self.num_classes = num_classes
        self.valid_acc = torchmetrics.Accuracy()
        self.valid_f1 = torchmetrics.F1()


    def forward(self, *input, **kwargs):
        return self.model(*input, **kwargs)

    def training_step(self, batch, batch_idx):
        x, y_true = batch
        loss, _ = self(x, labels=y_true)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_true = batch
        _, y_pred = self(x, labels=y_true)
        preds = torch.argmax(y_pred, axis=1)
        self.valid_acc(preds, y_true[0])
        self.log('val_acc', self.valid_acc, prog_bar=True)
        self.valid_f1(preds, y_true[0])
        self.log('f1', self.valid_f1, prog_bar=True)     

    def configure_optimizers(self):
        'Prepare optimizer and schedule (linear warmup and decay)'
        opt = torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    def training_epoch_end(self, training_step_outputs):
        avg_loss = torch.tensor([x['loss']
                                 for x in training_step_outputs]).mean()
        self.log('train_loss', avg_loss)
        print(f'###score: train_loss### {avg_loss}')

    def validation_epoch_end(self, val_step_outputs):
        acc = self.valid_acc.compute()
        f1 = self.valid_f1.compute()
        self.log('val_score', acc)
        self.log('f1', f1)
        print(f'###score: val_score### {acc}')

    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("OntologyTaggerModel")       
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", default=2e-3, type=float)
        return parent_parser

Traceback (console output):

The print statements repeat.

Global seed set to 42
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/deprecate/deprecation.py:115: FutureWarning: The `F1` was deprecated since v0.7 in favor of `torchmetrics.classification.f_beta.F1Score`. It will be removed in v0.8.
  stream(template_mgs % msg_args)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration. You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).
  warnings.warn(*args, **kwargs)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type           | Params
---------------------------------------------
0 | model     | LSTMClassifier | 77.4 K
1 | valid_acc | Accuracy       | 0     
2 | valid_f1  | F1             | 0     
---------------------------------------------
77.4 K    Trainable params
0         Non-trainable params
77.4 K    Total params
0.310     Total estimated model params size (MB)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.
  warnings.warn(*args, **kwargs)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:52: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 4 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Validation sanity check: 0it [00:00, ?it/s]
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric Accuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric F1 was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)

type(X) <class 'torch.Tensor'>
len(X) 10
X.shape torch.Size([10, 100])
type(labels) <class 'list'>
len(labels) 1
labels [tensor([ 2, 31, 26, 37, 22,  5, 31, 36,  5, 10])]
type(hidden[0]) <class 'torch.Tensor'>
len(hidden[0]) 2
hidden[0].shape torch.Size([2, 10, 50])
hidden[0] [tensor([ 2, 31, 26, 37, 22,  5, 31, 36,  5, 10])]
type(logits) <class 'torch.Tensor'>
len(logits) 10
logits.shape torch.Size([10, 38])
len(self.num_labels) 1
self.num_labels[0] 38
len(labels[0].view(-1)) 10
...
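
For completeness, the shapes printed above can be reproduced with a standalone dummy batch (a minimal sketch, not part of the training pipeline; num_classes=[38], batch size 10, and vocab size 128 come from the defaults and the log, while the sequence length of 25 is arbitrary):

import torch

# Dummy-batch check of LSTMClassifier: confirms the logits layout seen in the log above.
model = LSTMClassifier(num_classes=[38])     # num_classes is indexed as num_labels[0] in __init__
tokens = torch.randint(0, 128, (10, 25))     # (batch=10, seq_len=25), ids within vocab_size=128
labels = [torch.randint(0, 38, (10,))]       # labels arrive as a one-element list, as in the log
loss, logits = model(tokens, labels=labels)
print(logits.shape)                          # torch.Size([10, 38])
print(loss)                                  # scalar cross-entropy loss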
