This is the model structure. To make hyperparameter tuning easier, I put almost everything together in one class.

'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiTaskDNN(nn.Module):
    def __init__(self, n_tasks,
                 input_dim=1024,
                 output_dim=1,
                 hidden_dim=[1024, 100],
                 inits=['xavier_normal', 'kaiming_uniform'],
                 act_function=['relu', 'leaky_relu'],
                 dropouts=[0.10, 0.25],
                 batch_norm=True):
        super(MultiTaskDNN, self).__init__()
        self.n_tasks = n_tasks
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.act_function = act_function
        self.batch_norm = batch_norm

        # Shared trunk: one Linear / BatchNorm / Dropout block per hidden layer
        current_dim = input_dim
        self.layers = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.bns = nn.ModuleList()
        for k, hdim in enumerate(hidden_dim):
            self.layers.append(nn.Linear(current_dim, hdim))
            self.bns.append(nn.BatchNorm1d(hdim, eps=2e-1))
            current_dim = hdim
            # Per-layer weight initialization, chosen by name
            if inits[k] == 'xavier_normal':
                nn.init.xavier_normal_(self.layers[k].weight)
            elif inits[k] == 'kaiming_normal':
                nn.init.kaiming_normal_(self.layers[k].weight)
            elif inits[k] == 'xavier_uniform':
                nn.init.xavier_uniform_(self.layers[k].weight)
            elif inits[k] == 'kaiming_uniform':
                nn.init.kaiming_uniform_(self.layers[k].weight)
            self.dropouts.append(nn.Dropout(dropouts[k]))

        # One output head (Linear) per task
        self.heads = nn.ModuleList()
        for _ in range(self.n_tasks):
            self.heads.append(nn.Linear(current_dim, output_dim))

    def forward(self, x):
        # Shared layers: Linear -> activation -> (optional) BatchNorm -> Dropout
        for k, layer in enumerate(self.layers):
            x = layer(x)
            if self.act_function[k] == 'sigmoid':
                x = torch.sigmoid(x)
            elif self.act_function[k] == 'relu':
                x = F.relu(x)
            elif self.act_function[k] == 'leaky_relu':
                x = F.leaky_relu(x)
            if self.batch_norm:
                x = self.bns[k](x)
            x = self.dropouts[k](x)
        # One output tensor per task head
        outputs = [head(x) for head in self.heads]
        return outputs
'''
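For context, here is a minimal sketch of how I construct the model and run a forward pass. The hyperparameters, batch size, and random input below are only illustrative assumptions, not my actual training setup (my real run used different hidden sizes, as the printed model further down shows).

'''
# Illustrative usage only (reuses torch and the MultiTaskDNN class defined above);
# the constructor arguments here are assumptions, not my real configuration.
model = MultiTaskDNN(n_tasks=10,
                     input_dim=1024,
                     hidden_dim=[128, 128],
                     inits=['xavier_normal', 'xavier_normal'],
                     act_function=['relu', 'relu'],
                     dropouts=[0.25, 0.25])

x = torch.randn(32, 1024)              # dummy batch of 32 input vectors
outputs = model(x)                     # list of n_tasks tensors, each of shape (32, 1)
print(len(outputs), outputs[0].shape)  # -> 10 torch.Size([32, 1])
'''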
Please also let me know if the structure looks correct. After training this multi-task model with 10 tasks (heads), I want to predict only task 7, i.e. head 7. How should I load the trained model and make predictions for just that head? (My rough attempt is sketched after the printed model below.) Thank you.
Here is the printed model, i.e. the output of print(model):

'''
MultiTaskDNN(
(layers): ModuleList(
(0): Linear(in_features=1024, out_features=128, bias=True)
(1): Linear(in_features=128, out_features=128, bias=True)
)
(dropouts): ModuleList(
(0): Dropout(p=0.25, inplace=False)
(1): Dropout(p=0.25, inplace=False)
)
(bns): ModuleList(
(0): BatchNorm1d(128, eps=0.2, momentum=0.1, affine=True, track_running_stats=True)
(1): BatchNorm1d(128, eps=0.2, momentum=0.1, affine=True, track_running_stats=True)
)
(heads): ModuleList(
(0): Linear(in_features=128, out_features=1, bias=True)
(1): Linear(in_features=128, out_features=1, bias=True)
(2): Linear(in_features=128, out_features=1, bias=True)
(3): Linear(in_features=128, out_features=1, bias=True)
(4): Linear(in_features=128, out_features=1, bias=True)
(5): Linear(in_features=128, out_features=1, bias=True)
(6): Linear(in_features=128, out_features=1, bias=True)
(7): Linear(in_features=128, out_features=1, bias=True)
(8): Linear(in_features=128, out_features=1, bias=True)
(9): Linear(in_features=128, out_features=1, bias=True)
(10): Linear(in_features=128, out_features=1, bias=True)
)
)
'''
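I imagine loading and predicting would look something like the sketch below, but I am not sure whether this is the right way. The checkpoint path 'multitask_dnn.pt', the constructor arguments, and indexing the returned list with 7 are my own assumptions; as far as I understand, the constructor arguments have to match whatever was used at training time before load_state_dict will succeed.

'''
# Sketch only -- is this the correct approach? Path and arguments are assumptions.
model = MultiTaskDNN(n_tasks=10,
                     input_dim=1024,
                     hidden_dim=[128, 128],
                     inits=['xavier_normal', 'xavier_normal'],
                     act_function=['relu', 'relu'],
                     dropouts=[0.25, 0.25])

# Assumes the weights were saved earlier with torch.save(model.state_dict(), 'multitask_dnn.pt')
model.load_state_dict(torch.load('multitask_dnn.pt'))
model.eval()  # disable dropout and use BatchNorm running statistics

with torch.no_grad():
    x = torch.randn(5, 1024)   # dummy inputs; real feature vectors would go here
    outputs = model(x)         # list with one tensor per head
    task7_pred = outputs[7]    # prediction from head 7 only, shape (5, 1)
print(task7_pred)
'''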