1

我在 MNIST 上运行以下代码

即,我从每个验证时期返回

return {"val_loss": loss, "recon_batch": recon_batch, "label_batch": label_batch, "label_img": orig_batch.view(-1, 1, 28, 28)}

然后使用

    mat = torch.cat([o["recon_batch"] for o in outputs])
    metadata = torch.cat([o["label_batch"] for o in outputs]).cpu()
    label_img = torch.cat([o["label_img"] for o in outputs]).cpu()
    tb.add_embedding(
        mat=mat,
        metadata=metadata,
        label_img=label_img,
        global_step=self.current_epoch,
    )

并期望它能够工作,就像在doc中一样。

似乎只显示了一个批次,在验证过程中我得到的日志如下

验证:92%|█████████▏| 49/53 [00:01<00:00, 41.05it/s] 警告:嵌入目录存在,您是否为 add_embedding() 设置了 global_step?

recon_batch如何获得适用于所有时代的 TSNE ?


完整代码供参考:

def validation_step(self, batch, batch_idx):
    if self._config.dataset == "toy":
        (orig_batch, noisy_batch), label_batch = batch
        # TODO put in the noise here and not in the dataset?
    elif self._config.dataset == "mnist":
        orig_batch, label_batch = batch
        orig_batch = orig_batch.reshape(-1, 28 * 28)
        noisy_batch = orig_batch
    else:
        raise ValueError("invalid dataset")

    noisy_batch = noisy_batch.view(noisy_batch.size(0), -1)

    recon_batch, mu, logvar = self.forward(noisy_batch)

    loss = self._loss_function(
        recon_batch,
        orig_batch, mu, logvar,
        reconstruction_function=self._recon_function
    )

    tb = self.logger.experiment
    tb.add_scalars("losses", {"val_loss": loss}, global_step=self.current_epoch)
    if batch_idx == len(self.val_dataloader()) - 2:
        orig_batch -= orig_batch.min()
        orig_batch /= orig_batch.max()
        recon_batch -= recon_batch.min()
        recon_batch /= recon_batch.max()

        orig_grid = torchvision.utils.make_grid(orig_batch.view(-1, 1, 28, 28))
        val_recon_grid = torchvision.utils.make_grid(recon_batch.view(-1, 1, 28, 28))

        tb.add_image("original_val", orig_grid, global_step=self.current_epoch)
        tb.add_image("reconstruction_val", val_recon_grid, global_step=self.current_epoch)
        # f, axarr = plt.subplots(2, 1)
        # axarr[0].imshow(orig_grid.permute(1, 2, 0).cpu())
        # axarr[1].imshow(val_recon_grid.permute(1, 2, 0).cpu())
        # plt.show()
        pass

    return {"val_loss": loss, "recon_batch": recon_batch, "label_batch": label_batch,
            "label_img": orig_batch.view(-1, 1, 28, 28)}

def validation_epoch_end(self, outputs: List[Any]) -> None:
    first_batch_dict = outputs[-1]
    self.log(name="val_epoch_end", value={"val_loss": first_batch_dict["val_loss"]})

    tb = self.logger.experiment
    # assert mat.shape[0] == label_img.shape[0], '#images should equal with #data points'
    mat = torch.cat([o["recon_batch"] for o in outputs])
    metadata = torch.cat([o["label_batch"] for o in outputs]).cpu()
    label_img = torch.cat([o["label_img"] for o in outputs]).cpu()
    tb.add_embedding(
        mat=mat,
        metadata=metadata,
        label_img=label_img,
        global_step=self.current_epoch,
    )
4

0 回答 0