2

当使用该模块加载预训练的 VGG 网络torchvision.models并使用它对任意 RGB 图像进行分类时,网络的输出因调用而异。为什么会这样?据我了解,VGG 前向传递的任何部分都不应该是不确定的。

这是一个 MCVE:

import torch
from torchvision.models import vgg16

vgg = vgg16(pretrained=True)

img = torch.randn(1, 3, 256, 256)

torch.all(torch.eq(vgg(img), vgg(img))) # result is 0, but why?
4

1 回答 1

2

vgg16有一个nn.Dropout层,在训练期间会随机丢弃 50% 的输入。在测试期间,您应该通过将网络模式设置为“评估”模式来“关闭”这种行为:

vgg.eval()
torch.all(torch.eq(vgg(img), vgg(img)))
Out[73]: tensor(1, dtype=torch.uint8)

请注意,还有其他层具有随机行为和用于训练和评估的不同行为(例如,BatchNorm)。因此,在评估经过训练的模型之前切换到eval()模式很重要。

于 2019-05-06T10:44:04.883 回答