我正在重新实现可逆残差网络架构。
class iResNetBlock(nn.Module):
def __init__(self, input_size, hidden_size):
self.bottleneck = nn.Sequential(
LinearContraction(input_size, hidden_size),
LinearContraction(hidden_size, input_size),
nn.ReLU(),
)
def forward(self, x):
return x + self.bottleneck(x)
def inverse(self, y):
x = y.clone()
while not converged:
# fixed point iteration
x = y - self.bottleneck(x)
return x
我想为inverse
函数添加一个自定义的反向传递。由于它是定点迭代,因此可以利用隐函数定理来避免循环展开,而是通过求解线性系统来计算梯度。例如,这是在深度平衡模型架构中完成的。
def inverse(self, y):
with torch.no_grad():
x = y.clone()
while not converged:
# fixed point iteration
x = y - self.bottleneck(x)
return x
def custom_backward_inverse(self, grad_output):
pass
如何为此功能注册我的自定义反向通行证?我希望,当我稍后定义一些损失时 r = loss(y, model.inverse(other_model(model(x))))
,r.backwards()
正确地使用我的自定义渐变进行反向调用。
理想情况下,解决方案应该是torchscript
兼容的。