我需要并行训练两个模型。每个模型都有不同的激活函数和可训练的参数。我想训练模型一和模型二,使模型一(例如,alpha1)的激活函数的参数与模型二(例如,alpha2)中的参数间隔2;即,|alpha_1 - alpha_2| > 2. 我想知道如何将其包含在训练的损失函数中。
1 回答
示例模块定义
我将使用torch.nn.PReLU
您所说的参数激活。
get_weight
为方便而创建。
import torch
class Module(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.input = torch.nn.Linear(in_features, 2 * in_features)
self.activation = torch.nn.PReLU()
self.output = torch.nn.Linear(2 * in_features, out_features)
def get_weight(self):
return self.activation.weight
def forward(self, inputs):
return self.output(self.activation(self.inputs(inputs)))
模块和设置
在这里,我使用一个优化器来优化您谈论的两个模块的参数。criterion
可以mean squared error
,cross entropy
或者你需要的任何其他东西。
module1 = Module(20, 1)
module2 = Module(20, 1)
optimizer = torch.optim.Adam(
itertools.chain(module1.parameters(), module2.parameters())
)
critertion = ...
训练
这是一个步骤,您应该像往常一样将其打包在数据上的 for 循环中,希望这足以让您明白这一点:
inputs = ...
targets = ...
output1 = module1(inputs)
output2 = module2(inputs)
loss1 = criterion(output1, targets)
loss2 = criterion(output2, targets)
total_loss = loss1 + loss2
total_loss += torch.nn.functional.relu(
2 - torch.abs(module1.get_weight() - module2.get_weight()).sum()
)
total_loss.backward()
optimizer.step()
在这种情况下,这一行就是您所追求的:
total_loss += torch.nn.functional.relu(
2 - torch.abs(module1.get_weight() - module2.get_weight()).sum()
)
relu
被使用,因此网络不会仅仅从创建不同的权重中获得无限的好处。如果没有,损失将变为负值,权重之间的差异越大。在这种情况下,差异越大越好,但在差距大于或等于 之后就没有区别了2
。
如果您必须通过阈值来优化价值,您可能需要增加2
或其他东西,因为当它接近时优化价值会很小。2.1
2
2.0
编辑
如果没有明确给出阈值,可能会很难,但也许这样的事情会起作用:
total_loss = (
(torch.abs(module1) + torch.abs(module2)).sum()
+ (1 / torch.abs(module1) + 1 / torch.abs(module2)).sum()
- torch.abs(module1 - module2).sum()
)
这对网络来说有点骇人听闻,但可能值得一试(如果您应用额外的L2
正则化)。
从本质上讲,这种损失将-inf, +inf
在相应位置的成对权重处具有最佳值,并且永远不会小于零。
对于那些重量
weights_a = torch.tensor([-1000.0, 1000, -1000, 1000, -1000])
weights_b = torch.tensor([1000.0, -1000, 1000, -1000, 1000])
每个部分的损失将是:
(torch.abs(module1) + torch.abs(module2)).sum() # 10000
(1 / torch.abs(module1) + 1 / torch.abs(module2)).sum() # 0.0100
torch.abs(module1 - module2).sum() # 10000
在这种情况下,网络可以通过在两个模块中使用相反的符号使权重更大而忽略您想要优化的内容(L2
两个模块的权重较大可能会有所帮助,我认为最佳值将是1
/-1
如果L2
'salpha
相等to 1
) 我怀疑网络可能非常不稳定。
有了这个损失函数,如果网络得到大权重错误的迹象,它将受到严重惩罚。
在这种情况下,您将需要L2
调整 alpha 参数以使其工作,这不是那么严格,但仍然需要选择超参数。