0

我知道对于多类,F1 (micro) 与 Accuracy 相同。我的目标是在 Torch Lightning 中测试二进制分类,但总是得到相同的 F1 和准确度。

为了获得更多详细信息,我在GIST分享了我的代码,在那里我使用了MUTAG数据集。以下是我想提出讨论的一些重要部分

我计算精度和 F1 的函数(第 28-40 行)

def evaluate(self, batch, stage=None):
        y_hat = self(batch.x, batch.edge_index, batch.batch)
        loss = self.criterion(y_hat, batch.y)
        preds = torch.argmax(y_hat.softmax(dim=1), dim=1)
        acc = accuracy(preds, batch.y)
        f1_score = f1(preds, batch.y)

        if stage:
            self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True, logger=True)
            self.log(f"{stage}_acc", acc, on_step=True, on_epoch=True, logger=True)
            self.log(f"{stage}_f1", f1_score, on_step=True, on_epoch=True, logger=True)

        return loss

为了检查,我在第 35 行放置了一个检查点,得到acc=0.5, f1_score=0.5, whilepredictionlabel分别是

preds = tensor([1, 1, 1, 0, 1, 1, 1, 1, 0, 0])
batch.y = tensor([1, 0, 1, 1, 0, 1, 0, 1, 1, 0])

使用这些值,我运行一个笔记本来仔细检查scikit-learn

from sklearn.metrics import f1_score
y_hat = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
y = [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]
f1_score(y_hat, y, average='binary') # got 0.6153846153846153
accuracy_score(y_hat, y) # 0.5

与评估的代码相比,我得到了不同的结果。此外,我再次验证torch,有趣的是,我得到了正确的结果

from torchmetrics.functional import accuracy, f1
import torch
f1(torch.Tensor(y_hat), torch.LongTensor(y)) # tensor(0.6154)
accuracy(torch.Tensor(pred), torch.LongTensor(true)) # tensor(0.5000)

我想不知何故torch-lightning将我的计算视为一项多类任务。我的问题是如何纠正它的行为?

4

0 回答 0