0

我目前正在尝试拟合一个非常简单的模型,该模型基本上应该为直方图匹配找到最佳直方图。我写了一个超级简单的模型,只有一个我直接使用的 Parameter 对象:

import torch.nn.functional as F

class AutoHist(pl.LightningModule):    
    def __init__(self, channel=1, bins=255):
        super().__init__()
        self.hist = torch.nn.Parameter(torch.rand((1, channel, bins), requires_grad=True))
        self.eps = 1e-5    

    def b_distance(self, h1, h2):
        distance = 1
        distance -= 1/(torch.sqrt(torch.mean(h1, axis=2)*torch.mean(h2, axis=2)*h1.size(2)**2))
        distance *= torch.sum(torch.sqrt(h1*h2 + self.eps),axis=2)
        return torch.sqrt(distance + self.eps)    

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        x, y = batch
        hist = self.hist / self.hist.sum()
        distances = self.b_distance(hist,x)
        loss = F.binary_cross_entropy(distances[:,0], y)
        return loss    

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

但是由于某种原因,反向传播没有通过并且参数没有得到更新。有人知道问题可能出在哪里吗?梯度实际上是存在的,并且确实会随着批次的变化而变化。我使用 pytorch 闪电删除样板代码,但症结应该在于我编写的代码。

4

0 回答 0