I am trying to compute the gradient of a simple neural network's output with respect to its input. As long as I do not use a BatchNorm layer, the results look fine. As soon as I add one, the gradients no longer seem to make much sense. Below is a short example that reproduces the effect.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


class Net(nn.Module):
    def __init__(self, batch_norm):
        super().__init__()
        self.batch_norm = batch_norm
        self.act_fn = nn.Tanh()
        self.aff1 = nn.Linear(1, 10)
        self.aff2 = nn.Linear(10, 1)
        if batch_norm:
            self.bn = nn.BatchNorm1d(10, affine=False)  # False for simplicity

    def forward(self, x):
        x = self.aff1(x)
        x = self.act_fn(x)
        if self.batch_norm:
            x = self.bn(x)
        x = self.aff2(x)
        return x
x_vals = torch.linspace(0, 1, 100)
x_vals.requires_grad = True

fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
for seed, bn, ax1 in zip([11, 4], [False, True], axs):  # different seeds for better illustration of the effect
    torch.manual_seed(seed)
    net = Net(batch_norm=bn)
    net.train()
    pred = net(x_vals[:, None])
    pred_dx = torch.autograd.grad(pred.sum(), x_vals, create_graph=True)[0]

    # visualization
    ax2 = ax1.twinx()
    ax1.plot(x_vals.detach(), pred.detach())
    ax2.plot(x_vals.detach(), pred_dx.detach(), linestyle='--', color='orange')
    min_idx = torch.argmin((pred[1:] - pred[:-1])**2)
    ax2.axvline(x_vals[min_idx].detach(), color='gray', linestyle='dotted')
    ax2.axhline(0, color='gray', linestyle='dotted')
    ax1.set_title(('With' if bn else 'Without') + ' Batch Norm')
plt.show()
When I use evaluation mode, the results also look fine. Unfortunately, I cannot simply switch to eval() mode, because the nature of my problem (a PINN) requires computing the gradients during training.
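For reference, the eval-mode check I mentioned looks roughly like this (a minimal sketch reusing the Net class and x_vals defined above):

# rough sketch of the eval-mode check: with net.eval(), BatchNorm normalizes
# with its running statistics instead of the current batch's statistics
torch.manual_seed(4)
net = Net(batch_norm=True)
net.eval()
pred = net(x_vals[:, None])
pred_dx = torch.autograd.grad(pred.sum(), x_vals, create_graph=True)[0]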
I understand that the running mean and variance are updated during training. Maybe that has an influence? Can I still somehow obtain the correct gradients?
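My current understanding (which may be wrong) is that in train() mode BatchNorm normalizes with the statistics of the current batch rather than the running statistics, so each output depends on every sample in the batch. Here is a small sanity check of that, my own sketch rather than anything from the docs:

# in train mode, BatchNorm1d appears to use the batch's own mean and (biased)
# variance, so the normalization couples the samples of the batch
torch.manual_seed(0)
bn = nn.BatchNorm1d(10, affine=False)
bn.train()

x = torch.randn(100, 10)
out = bn(x)
manual = (x - x.mean(0)) / torch.sqrt(x.var(0, unbiased=False) + bn.eps)
print(torch.allclose(out, manual, atol=1e-5))  # expected: True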
Thanks for your help!