我想把模型的预测结果和对应的实际图像一起可视化,以便了解算法的表现,并找出哪些标签被预测错了。但是,虽然我在写入 TensorBoard 时设置了 step,可视化界面中却没有显示所有步骤,因此看不到全部训练图像及其标签,只能看到其中的少数几个样本。
# TensorBoard writer: event files are written under the "graphs" directory
# (view them with `tensorboard --logdir graphs`).
writer = SummaryWriter(log_dir='graphs')
def image_to_display_array(img):
    """Convert a CHW image tensor into a uint8 array that ``plt.imshow`` accepts.

    Values are clipped to [0, 1] before being scaled to [0, 255]. Without the
    clip, normalized inputs (negative or > 1 values) wrap around when cast to
    uint8 and show up as noise. A single-channel image is returned as a 2-D
    array because ``plt.imshow`` rejects arrays shaped (H, W, 1).

    Args:
        img: Tensor shaped (C, H, W). Assumed to hold floats where 0..1 is the
            displayable range (TODO confirm against the dataset transforms).

    Returns:
        numpy.ndarray of dtype uint8, shaped (H, W, C), or (H, W) when C == 1.
    """
    npimg = img.detach().cpu().numpy()
    npimg = np.transpose(npimg, (1, 2, 0))
    if npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]
    return (np.clip(npimg, 0.0, 1.0) * 255).astype(np.uint8)


def matplotlib_imshow(img):
    """Draw a CHW image tensor onto the current matplotlib axes.

    Args:
        img: Tensor shaped (C, H, W); see ``image_to_display_array``.
    """
    npimg = image_to_display_array(img)
    # matplotlib ignores cmap for RGB(A) arrays, so this only affects
    # single-channel (2-D) images.
    plt.imshow(npimg, cmap="gray")
def images_to_probs(net, images):
    """Run ``net`` on a batch and return its top-1 predictions and confidences.

    Args:
        net: Callable mapping ``images`` to logits shaped (batch, classes).
        images: Input batch accepted by ``net``.

    Returns:
        A tuple ``(preds, probs)``:
        - ``preds`` is a 1-D numpy array of predicted class indices.
        - ``probs`` is a list of Python floats holding the softmax
          probability of each predicted class.
    """
    # Inference only: don't build an autograd graph for the visualization pass.
    with torch.no_grad():
        output = net(images)
    _, preds_tensor = torch.max(output, 1)
    # Vectorized equivalent of softmax-per-row followed by picking the
    # predicted column.
    probs = F.softmax(output, dim=1).gather(1, preds_tensor.unsqueeze(1)).squeeze(1)
    # No np.squeeze here: with a batch of one it would produce a 0-d array,
    # which callers can neither index nor iterate.
    preds = preds_tensor.cpu().numpy()
    return preds, probs.cpu().tolist()
def plot_classes_preds(net, images, labels, num_images=4):
    """Plot up to ``num_images`` images with their predicted and true labels.

    Each subplot title shows:
    - the predicted class name;
    - its softmax probability;
    - the true class name.
    The title is green when the prediction is correct and red otherwise.
    Class names come from the module-level ``classes`` sequence.

    Args:
        net: Model passed to ``images_to_probs``.
        images: Batch of CHW image tensors.
        labels: Tensor of integer class indices, one per image.
        num_images: Maximum number of images to draw (default 4, as before).

    Returns:
        The matplotlib ``Figure``, suitable for ``SummaryWriter.add_figure``.
    """
    preds, probs = images_to_probs(net, images)
    # The last batch of an epoch can hold fewer than num_images samples;
    # indexing past it used to raise IndexError.
    count = min(num_images, len(images))
    fig = plt.figure(figsize=(6, 6))
    for idx in range(count):
        ax = fig.add_subplot(1, count, idx + 1, xticks=[], yticks=[])
        matplotlib_imshow(images[idx])
        pred = int(preds[idx])
        label = int(labels[idx].item())
        ax.set_title(
            "{0}, {1:.1f}%\n(label: {2})".format(
                classes[pred],
                probs[idx] * 100.0,
                classes[label],
            ),
            color=("green" if pred == label else "red"),
        )
    return fig
以下是我的训练循环,其中用 step 变量作为全局步数(global_step)。
# Global step counter shared by all epochs. It used to be reset to 0 at the
# start of every epoch, so each epoch re-logged its figures at steps 0..N-1
# and TensorBoard showed the epochs on top of each other.
step = 0
# Log a prediction figure every `log_every` batches. Rendering a matplotlib
# figure for every batch is slow; raise this to log less often.
log_every = 1
# NOTE: TensorBoard keeps only a random sample of image events per tag
# (10 by default), no matter how many were written. To browse more steps,
# start it with e.g.:
#   tensorboard --logdir graphs --samples_per_plugin images=1000
for epoch in range(epochs):
    epoch_start_time = time.time()
    losses = []
    total_batch_images = 0
    batch_correct_pred = 0
    # save model
    # if batch_accuracy > best_acc:
    #     best_acc = batch_accuracy
    #     checkpoint = {'state_dict': model.state_dict(), 'acc': batch_accuracy, 'epoch': epoch, 'optimizer': optimizer.state_dict()}
    #     save_checkpoint(checkpoint)
    model.train()
    for batch_idx, (images, labels) in enumerate(train_loader):
        # Get data to cuda if possible
        images = images.to(device=device)
        labels = labels.to(device=device)

        # forward
        scores = model(images)
        loss = criterion(scores, labels)
        losses.append(loss.item())

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

        # visualizing Dataset images
        # img_grid = torchvision.utils.make_grid(images)
        # writer.add_image('Xray_images', img_grid, global_step=step)

        # Running accuracy, reusing the scores from the training forward pass.
        _, predictions = scores.max(1)
        num_correct = (predictions == labels).sum()
        batch_correct_pred += float(num_correct)
        total_batch_images += predictions.size(0)

        if step % log_every == 0:
            # Plot in eval mode, then switch back to train mode. The original
            # called model.eval() here and never undid it, so every batch
            # after the first trained in evaluation mode (affecting layers
            # such as dropout and batch-norm, if the model has them).
            model.eval()
            writer.add_figure('predictions vs. actuals',
                              plot_classes_preds(model, images, labels),
                              global_step=step)
            model.train()
        step += 1
我认为问题出在最后调用 writer.add_figure() 时设置 global_step 的那一行。如能提供任何帮助,将不胜感激。