0

我正在尝试使用 BERT 实现情感分析代码(正面或负面标签),我想添加一个 BiLSTM 层,看看我是否可以提高 HuggingFace 预训练模型的准确性。我有以下代码和几个问题:

import numpy as np
import pandas as pd
from sklearn import metrics
import transformers
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertModel, BertConfig
from torch import cuda
import re
import torch.nn as nn

device = 'cuda' if cuda.is_available() else 'cpu'
MAX_LEN = 200
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 4
EPOCHS = 1
LEARNING_RATE = 1e-05 #5e-5, 3e-5 or 2e-5
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class CustomDataset(Dataset):
 def __init__(self, dataframe, tokenizer, max_len):
  self.tokenizer = tokenizer
  self.data = dataframe
  self.comment_text = dataframe.review
  self.targets = self.data.sentiment
  self.max_len = max_len
 def __len__(self):
  return len(self.comment_text)
 def __getitem__(self, index):
  comment_text = str(self.comment_text[index])
  comment_text = " ".join(comment_text.split())

  inputs = self.tokenizer.encode_plus(comment_text,None,add_special_tokens=True,max_length=self.max_len,
   pad_to_max_length=True,return_token_type_ids=True)
  ids = inputs['input_ids']
  mask = inputs['attention_mask']
  token_type_ids = inputs["token_type_ids"]

  return {
   'ids': torch.tensor(ids, dtype=torch.long),
   'mask': torch.tensor(mask, dtype=torch.long),
   'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
   'targets': torch.tensor(self.targets[index], dtype=torch.float)
  }
train_size = 0.8
train_dataset=df.sample(frac=train_size,random_state=200)
test_dataset=df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

training_set = CustomDataset(train_dataset, tokenizer, MAX_LEN)
testing_set = CustomDataset(test_dataset, tokenizer, MAX_LEN)
train_params = {'batch_size': TRAIN_BATCH_SIZE,'shuffle': True,'num_workers': 0}
test_params = {'batch_size': VALID_BATCH_SIZE,'shuffle': True,'num_workers': 0}
training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)


class BERTClass(torch.nn.Module):
 def __init__(self):
   super(BERTClass, self).__init__()
   self.bert = BertModel.from_pretrained('bert-base-uncased',return_dict=False, num_labels =2)
   self.lstm = nn.LSTM(768, 256, batch_first=True, bidirectional=True)
   self.linear = nn.Linear(256*2,2)

 def forward(self, ids , mask,token_type_ids):
  sequence_output, pooled_output = self.bert(ids, attention_mask=mask, token_type_ids = token_type_ids)
  lstm_output, (h, c) = self.lstm(sequence_output)  ## extract the 1st token's embeddings
  hidden = torch.cat((lstm_output[:, -1, :256], lstm_output[:, 0, 256:]), dim=-1)
  linear_output = self.linear(lstm_output[:, -1].view(-1, 256 * 2))

  return linear_output

model = BERTClass()
model.to(device)
print(model)
def loss_fn(outputs, targets):
 return torch.nn.BCEWithLogitsLoss()(outputs, targets)
optimizer = torch.optim.Adam(params =  model.parameters(), lr=LEARNING_RATE)

def train(epoch):
 model.train()
 for _, data in enumerate(training_loader, 0):
  ids = data['ids'].to(device, dtype=torch.long)
  mask = data['mask'].to(device, dtype=torch.long)
  token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
  targets = data['targets'].to(device, dtype=torch.float)
  outputs = model(ids, mask, token_type_ids)
  optimizer.zero_grad()
  loss = loss_fn(outputs, targets)
  if _ % 5000 == 0:
   print(f'Epoch: {epoch}, Loss:  {loss.item()}')
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

for epoch in range(EPOCHS):
  train(epoch)

所以在上面的代码中我遇到了错误: Target size (torch.Size([8])) must be the same as input size (torch.Size([8, 2]))。在线检查并尝试使用targets = targets.unsqueeze(2),但后来我收到另一个错误,我必须使用 [-2,1] 中的值进行解压。我还尝试将损失函数修改为

def loss_fn(outputs, targets):
 return torch.nn.BCELoss()(outputs, targets)

但我仍然收到同样的错误。有人可以建议是否有解决此问题的方法?或者我该怎么做才能使这项工作正常进行?提前谢谢了。

4

0 回答 0