torch.nn has the classes BatchNorm1d, BatchNorm2d and BatchNorm3d, but it has no BatchNorm class for fully connected layers. What is the standard way to apply ordinary batch normalization in PyTorch?
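OK, I figured it out. BatchNorm1d also accepts rank-2 tensors (input of shape (N, C), i.e. batch size by number of features), so BatchNorm1d can be used for the ordinary fully connected case.
For example: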
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu


class Policy(nn.Module):
    def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space

        self.linear1 = nn.Linear(num_inputs, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, num_outputs)
        # BatchNorm1d accepts (N, C) input, so it can follow a Linear layer directly
        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, inputs):
        x = inputs
        x = self.bn1(F.relu(self.linear1(x)))
        x = self.bn2(F.relu(self.linear2(x)))
        out = self.linear3(x)
        return out
answered 2017-11-09T12:44:44.690
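To check the claim that BatchNorm1d handles a plain (N, C) tensor the way "fully connected" batch norm should, the sketch below compares its training-mode output with a manual per-feature normalization over the batch dimension. It assumes a freshly constructed layer, whose learnable weight (gamma) is 1 and bias (beta) is 0, so the affine step is the identity; the batch and feature sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A batch of 8 samples with 4 features each: a rank-2 tensor of shape (N, C)
x = torch.randn(8, 4)

bn = nn.BatchNorm1d(4)  # num_features = C
bn.train()              # training mode: normalize with the current batch statistics
y = bn(x)

# Manual version: per-feature mean and biased variance, computed over the batch (dim 0)
mean = x.mean(dim=0)
var = ((x - mean) ** 2).mean(dim=0)
y_manual = (x - mean) / torch.sqrt(var + bn.eps)

print(y.shape)                                  # torch.Size([8, 4])
print(torch.allclose(y, y_manual, atol=1e-5))   # True
```

Each column (feature) is normalized independently across the samples in the batch, which is exactly the batch norm used after fully connected layers.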
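BatchNorm1d is usually placed before the ReLU, and the bias of the preceding Linear layer is redundant: BatchNorm subtracts the per-feature batch mean (which cancels any constant offset) and then applies its own learnable shift. So: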
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu


class Policy(nn.Module):
    def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space

        # bias=False: the following BatchNorm1d provides the shift term instead
        self.linear1 = nn.Linear(num_inputs, hidden_size1, bias=False)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2, bias=False)
        self.linear3 = nn.Linear(hidden_size2, num_outputs)
        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, inputs):
        x = inputs
        # Linear -> BatchNorm -> ReLU
        x = F.relu(self.bn1(self.linear1(x)))
        x = F.relu(self.bn2(self.linear2(x)))
        out = self.linear3(x)
        return out
answered 2020-01-14T15:08:37.787
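The redundancy of the Linear bias can be checked directly. In this sketch two Linear layers share the same weights and differ only in whether they have a bias; feeding both through the same BatchNorm1d in training mode yields the same output, because subtracting the batch mean removes any constant per-feature offset. The layer sizes and random input are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(16, 10)  # batch of 16 samples, 10 input features

# Two Linear layers with identical weights; only one has a bias term
lin_bias = nn.Linear(10, 5, bias=True)
lin_nobias = nn.Linear(10, 5, bias=False)
with torch.no_grad():
    lin_nobias.weight.copy_(lin_bias.weight)

bn = nn.BatchNorm1d(5)
bn.train()  # normalize with batch statistics

out_bias = bn(lin_bias(x))
out_nobias = bn(lin_nobias(x))

# The per-feature bias is cancelled by the mean subtraction inside BatchNorm
print(torch.allclose(out_bias, out_nobias, atol=1e-5))  # True
```

Dropping the bias therefore saves parameters without changing what the network can represent; BatchNorm's learnable bias (beta) takes over the role of the shift.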