关键是 2D batchnorm 对每个通道执行相同的归一化。即,如果您有一批形状为(N、C、H、W)的数据,那么您的 mu 和 stddev 应该是形状(C,)。如果您的图像没有通道尺寸,请使用view
.
警告:如果您设置training=True
thenbatch_norm
计算并使用争论批次的适当归一化统计数据(这意味着我们不需要自己计算均值和标准差)。您争论的 mu 和 stddev 应该是所有训练批次的运行平均值和运行 std。batch_norm
这些张量使用函数中的新批次统计信息进行更新。
# inp is shape (N, C, H, W)
n_chans = inp.shape[1]
running_mu = torch.zeros(n_chans) # zeros are fine for first training iter
running_std = torch.ones(n_chans) # ones are fine for first training iter
x = nn.functional.batch_norm(inp, running_mu, running_std, training=True, momentum=0.9)
# running_mu and running_std now have new values
如果您只想使用自己的批处理统计信息,请尝试以下操作:
# inp is shape (N, C, H, W)
n_chans = inp.shape[1]
reshaped_inp = inp.permute(1,0,2,3).contiguous().view(n_chans, -1) # shape (C, N*W*H)
mu = reshaped_inp.mean(-1)
stddev = reshaped_inp.std(-1)
x = nn.functional.batch_norm(inp, mu, stddev, training=False)