我在 python 中使用 dgl 实现图形注意网络代码:
我想在实现的等式(2)中使用 min、max、sum 或 Hadamard 函数而不是串联函数。我怎样才能做到这一点?
等式(2)的代码如下:
class GATLayer(nn.Module):
def __init__(self, g, in_dim, out_dim):
super(GATLayer, self).__init__()
self.g = g
# equation (1)
self.fc = nn.Linear(in_dim, out_dim, bias=False)
# equation (2)
self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
self.reset_parameters()
def reset_parameters(self):
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.fc.weight, gain=gain)
nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)
def edge_attention(self, edges):
# edge UDF for equation (2)
z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
a = self.attn_fc(z2)
return {'e': F.leaky_relu(a)}
def message_func(self, edges):
# message UDF for equation (3) & (4)
return {'z': edges.src['z'], 'e': edges.data['e']}
连接函数torch.cat()
在下面的代码中:
class MultiHeadGATLayer(nn.Module):
def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
super(MultiHeadGATLayer, self).__init__()
self.heads = nn.ModuleList()
for i in range(num_heads):
self.heads.append(GATLayer(g, in_dim, out_dim))
self.merge = merge
def forward(self, h):
head_outs = [attn_head(h) for attn_head in self.heads]
if self.merge == 'cat':
# concat on the output feature dimension (dim=1)
#return torch.min(torch.stack(head_outs), dim=1)
return torch.cat(head_outs, dim=1)
else:
# merge using average
return torch.mean(torch.stack(head_outs))
如何更改原始实现的串联功能?