0

我正在使用 PyTorch 实现 DCGAN。

它的效果很好,因为我可以获得质量合理的生成图像,但是现在我想通过使用指标来评估 GAN 模型的健康状况,主要是本指南中介绍的指标https://machinelearningmastery.com/practical-guide-to- GAN故障模式/

他们的实现使用 Keras,该 SDK 允许您在编译模型时定义所需的指标,请参阅https://keras.io/api/models/model/。在这种情况下,鉴别器的准确性,即它成功地将图像识别为真实图像或生成图像的百分比。

使用 PyTorch SDK,我似乎找不到一个类似的功能可以帮助我轻松地从我的模型中获取这个指标。

Pytorch 是否提供能够从模型中定义和提取通用指标的功能?

4

1 回答 1

0

Pure PyTorch提供开箱即用的指标,但您自己定义这些指标非常容易。

也没有“从模型中提取指标”之类的东西。度量是度量,它们测量(在这种情况下是鉴别器的准确性),它们不是模型固有的。

二进制精度

在您的情况下,您正在寻找二进制精度指标。下面的代码适用于logits(由 输出的非标准化概率discriminator,可能是nn.Linear没有激活的最后一层)或probabilities(最后nn.Linearsigmoid激活):

import typing
import torch


class BinaryAccuracy:
    def __init__(
        self,
        logits: bool = True,
        reduction: typing.Callable[
            [
                torch.Tensor,
            ],
            torch.Tensor,
        ] = torch.mean,
    ):
        self.logits = logits
        if logits:
            self.threshold = 0
        else:
            self.threshold = 0.5

        self.reduction = reduction

    def __call__(self, y_pred, y_true):
        return self.reduction(((y_pred > self.threshold) == y_true.bool()).float())

用法:

metric = BinaryAccuracy()
target = torch.randint(2, size=(64,))
outputs = torch.randn(size=(64, 1))

print(metric(outputs, target))

PyTorch Lightning 或其他第三方

您还可以在 PyTorch 之上使用PyTorch Lightning或其他框架,这些框架定义了准确度等指标

于 2021-02-25T10:51:51.130 回答