Pure PyTorch不提供开箱即用的指标,但您自己定义这些指标非常容易。
也没有“从模型中提取指标”之类的东西。度量是度量,它们测量(在这种情况下是鉴别器的准确性),它们不是模型固有的。
二进制精度
在您的情况下,您正在寻找二进制精度指标。下面的代码适用于logits
(由 输出的非标准化概率discriminator
,可能是nn.Linear
没有激活的最后一层)或probabilities
(最后nn.Linear
是sigmoid
激活):
import typing
import torch
class BinaryAccuracy:
def __init__(
self,
logits: bool = True,
reduction: typing.Callable[
[
torch.Tensor,
],
torch.Tensor,
] = torch.mean,
):
self.logits = logits
if logits:
self.threshold = 0
else:
self.threshold = 0.5
self.reduction = reduction
def __call__(self, y_pred, y_true):
return self.reduction(((y_pred > self.threshold) == y_true.bool()).float())
用法:
metric = BinaryAccuracy()
target = torch.randint(2, size=(64,))
outputs = torch.randn(size=(64, 1))
print(metric(outputs, target))
PyTorch Lightning 或其他第三方
您还可以在 PyTorch 之上使用PyTorch Lightning或其他框架,这些框架定义了准确度等指标