1

使用 pytorch-lightning 和 transformers,我在德国服务票上微调了一个 Bert 模型。数据集的大小如下:

FULL Dataset: (1220, 2)
TRAIN Dataset: (854, 2)
VAL Dataset: (366, 2)

每张票可以恰好属于 10 个类别中的 1 个。这就是为什么我的模型在 def init中初始化的原因,例如:

#changing the configuration to X lables instead of 2
self.bert = transformers.BertModel.from_pretrained(MODEL_NAME)
self.drop = th.nn.Dropout(p=0.1)
self.out = th.nn.Linear(self.bert.config.hidden_size, NUM_LABELS)
self.softmax = th.nn.Softmax(dim=1)
self.loss = th.nn.CrossEntropyLoss(reduction="none")

这为每个样本产生了 10 个类别的概率分布。由于前向功能是:

 def forward(self, input_ids, mask):
    _, pooled_output = self.bert(
                                input_ids=input_ids,
                                attention_mask=mask
                        )
    output= self.drop(pooled_output)
    output = self.out(output)
    return self.softmax(output)

作为损失函数,torch.nn.CrossEntropyLoss 被定义并在 training_step 中被调用。Batch_Size 为 16,logits.shape = [16,10] 和 batch['targets'].shape = [16] 和 batch['targets'] = [1,5,2,4,8,6,9 ,0,0,1,2,7,7,7,5,3]。CrossEntropyLoss 是正确的损失函数吗?优化器是否还在工作?

 def training_step(self, batch, batch_idx):
    logits = self.forward(batch['input_ids'], batch['mask']).squeeze()
    loss = self.loss(logits, batch['targets']).mean()
    return {'loss': loss, 'log': {'train_loss': loss}}

验证步骤也是如此:

  def validation_step(self,batch, batch_idx):
      logits = self.forward(batch['input_ids'], batch['mask']).squeeze()
      acc = (logits.argmax(-1) == batch['targets']).float()
      loss = self.loss(logits, batch['targets'])
      return {'loss': loss, 'acc': acc}

最后,无论得到什么输入序列,模型都会产生相同的概率。这是欠拟合吗?

一个例子是:

model.eval()
text = "Warum kann mein Kollege in SAP keine Transaktionen mehr     ausführen?"
input = tokenizer(            
   text,
   None,
   add_special_tokens=True,
   max_length=200,
   pad_to_max_length=True,
   return_token_type_ids=True,
   truncation=True,
   padding='max_length',
   return_tensors="pt"
   )

input = input.to(device)

out = model(input_ids=input['input_ids'], mask=input['attention_mask'])
text, out

    --> ('Warum kann mein Kollege in SAP keine Transaktionen mehr ausführen?',
 tensor([[2.9374e-03, 3.1926e-03, 8.7949e-03, 3.0573e-01, 2.6428e-04, 5.2946e-02,
      2.4758e-01, 6.2161e-03, 3.6930e-01, 3.0384e-03]], device='cuda:0',
    grad_fn=<SoftmaxBackward>))

另一个例子是:

    ('Auf meinem Telefon erscheinen keine Nummern mehr.',
 tensor([[2.9374e-03, 3.1926e-03, 8.7949e-03, 3.0573e-01, 2.6428e-04, 5.2946e-02,
          2.4758e-01, 6.2161e-03, 3.6930e-01, 3.0384e-03]], device='cuda:0',
        grad_fn=<SoftmaxBackward>))

它是德语,但如您所见,预测完全匹配。这是一个耻辱 :D 我的问题。

4

0 回答 0