Pytorch 初学者在这里!考虑以下自定义模块:
class Testme(nn.Module):
def __init__(self):
super(Testme, self).__init__()
def forward(self, x):
return x / t_.max(x).expand_as(x)
据我了解文档:我相信这也可以作为自定义实现Function
。的子类Function
需要backward()
方法,但
Module
不需要。同样,在 Linear 的文档示例中Module
,它取决于 Linear Function
:
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
...
def forward(self, input):
return Linear()(input, self.weight, self.bias)
问题:我不明白 和 之间的Module
关系Function
。在上面的第一个清单(模块Testme
)中,它应该有关联的功能吗?如果不是,那么可以通过子类化 Module 来实现这一点而无需backward
方法,那么为什么Function
总是需要backward
方法呢?
也许Function
s 仅适用于不是由现有的火炬功能组成的功能?Function
换一种说法:如果模块的forward
方法完全由先前定义的 Torch 函数组成,也许模块不需要关联?