我在 MNIST 上运行了以下代码,
即在每个验证步骤(validation_step)中返回:
# Per-batch return value of validation_step (full code below); validation_epoch_end concatenates these dicts across batches.
return {"val_loss": loss, "recon_batch": recon_batch, "label_batch": label_batch, "label_img": orig_batch.view(-1, 1, 28, 28)}
然后在 validation_epoch_end 中使用:
# In validation_epoch_end: stitch every batch's outputs together and write them as one projector embedding, keyed by the current epoch.
# NOTE(review): add_embedding str()-formats each metadata item, so passing a tensor yields labels like "tensor(3)"; .tolist() gives plain ints.
mat = torch.cat([o["recon_batch"] for o in outputs])
metadata = torch.cat([o["label_batch"] for o in outputs]).cpu()
label_img = torch.cat([o["label_img"] for o in outputs]).cpu()
tb.add_embedding(
mat=mat,
metadata=metadata,
label_img=label_img,
global_step=self.current_epoch,
)
并期望它能像文档中描述的那样正常工作。
但投影器(projector)中似乎只显示了一个批次的数据,并且在验证过程中得到如下日志:
验证:92%|█████████▏| 49/53 [00:01<00:00, 41.05it/s] 警告:嵌入目录存在,您是否为 add_embedding() 设置了 global_step?
recon_batch
如何为所有 epoch 获得 t-SNE 可视化?
完整代码供参考:
def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch and collect what the epoch-end embedding needs.

    Args:
        batch: ``((orig, noisy), labels)`` when ``self._config.dataset == "toy"``;
            ``(images, labels)`` when it is ``"mnist"``.
        batch_idx: Index of this batch within the validation epoch.

    Returns:
        dict with ``val_loss``, ``recon_batch`` (detached model reconstructions),
        ``label_batch`` and ``label_img`` (the inputs, not rescaled, viewed as
        ``(-1, 1, 28, 28)``), consumed by ``validation_epoch_end``.

    Raises:
        ValueError: If ``self._config.dataset`` is neither ``"toy"`` nor ``"mnist"``.
    """

    def _minmax(t):
        # Out-of-place rescale to [0, 1] for display; the clamp avoids 0/0 on a constant tensor.
        shifted = t - t.min()
        return shifted / shifted.max().clamp(min=1e-12)

    if self._config.dataset == "toy":
        (orig_batch, noisy_batch), label_batch = batch
        # TODO put in the noise here and not in the dataset?
    elif self._config.dataset == "mnist":
        orig_batch, label_batch = batch
        orig_batch = orig_batch.reshape(-1, 28 * 28)
        noisy_batch = orig_batch
    else:
        raise ValueError("invalid dataset")
    noisy_batch = noisy_batch.view(noisy_batch.size(0), -1)
    recon_batch, mu, logvar = self.forward(noisy_batch)
    loss = self._loss_function(
        recon_batch,
        orig_batch, mu, logvar,
        reconstruction_function=self._recon_function,
    )
    tb = self.logger.experiment
    tb.add_scalars("losses", {"val_loss": loss}, global_step=self.current_epoch)

    # NOTE(review): val_dataloader() is called on every step just to take its length;
    # if it builds a new DataLoader each time, consider caching the length.
    if batch_idx == len(self.val_dataloader()) - 2:
        # Rescale *copies* only. Normalising orig_batch / recon_batch in place also
        # changed the tensors returned below, so this one batch's embedding vectors and
        # sprite images ended up on a different scale from every other batch (and the
        # caller's input batch was mutated, since reshape returns a view).
        orig_grid = torchvision.utils.make_grid(_minmax(orig_batch).view(-1, 1, 28, 28))
        val_recon_grid = torchvision.utils.make_grid(_minmax(recon_batch).view(-1, 1, 28, 28))
        tb.add_image("original_val", orig_grid, global_step=self.current_epoch)
        tb.add_image("reconstruction_val", val_recon_grid, global_step=self.current_epoch)

    return {
        "val_loss": loss,
        # Detached so that keeping every batch until epoch end cannot hold an autograd graph.
        "recon_batch": recon_batch.detach(),
        "label_batch": label_batch,
        "label_img": orig_batch.view(-1, 1, 28, 28),
    }
def validation_epoch_end(self, outputs: List[Any]) -> None:
    """Log the epoch's validation loss and write a TensorBoard projector embedding.

    Args:
        outputs: One dict per validation batch as returned by ``validation_step``
            (keys ``val_loss``, ``recon_batch``, ``label_batch``, ``label_img``).
    """
    if not outputs:
        # torch.stack / torch.cat raise on an empty list; nothing to log.
        return

    # Epoch-level loss is the mean over all batches; previously only the last batch's
    # loss was logged (outputs[-1], misleadingly named first_batch_dict).
    mean_val_loss = torch.stack([torch.as_tensor(o["val_loss"]) for o in outputs]).mean()
    self.log(name="val_epoch_end", value={"val_loss": mean_val_loss})

    # Skip the embedding during Lightning's sanity-check validation: it runs with the same
    # current_epoch as the first real epoch, so both would target the same
    # "<global_step>/<tag>" directory ("Embedding dir exists, did you set global_step for
    # add_embedding()?") and the projector would get a duplicate entry built from only
    # the few sanity-check batches. `sanity_checking` is the newer Trainer attribute,
    # `running_sanity_check` the older one.
    sanity_checking = getattr(self.trainer, "sanity_checking", None)
    if sanity_checking is None:
        sanity_checking = getattr(self.trainer, "running_sanity_check", False)
    if sanity_checking:
        return

    mat = torch.cat([o["recon_batch"] for o in outputs]).detach().cpu()
    # add_embedding writes each metadata entry with str(); plain ints show up as "3"
    # in the projector instead of "tensor(3)".
    metadata = torch.cat([o["label_batch"] for o in outputs]).cpu().tolist()
    label_img = torch.cat([o["label_img"] for o in outputs]).detach().cpu()
    self.logger.experiment.add_embedding(
        mat=mat,
        metadata=metadata,
        label_img=label_img,
        global_step=self.current_epoch,
    )