I would like to modify the example DGL GATLayer so that the network learns edge weights instead of node representations. That is, I want to build a network that takes a set of node features as input and outputs edges. The labels will be a set of "true edges" indicating which nodes came from a common source, so that I can learn to cluster unseen data in the same way.
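(To make the labels concrete: assuming each node carries some "source" id, the target for an edge (u, v) would be 1 exactly when its endpoints share a source. Roughly the following sketch, where source is a hypothetical per-node id tensor and g is the DGL graph:

    # hypothetical label construction: one 0/1 target per edge of g,
    # where "source" is a per-node tensor of common-source ids
    src_ids, dst_ids = g.edges()
    edge_labels = (source[src_ids] == source[dst_ids]).float()
)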
I am using the code from the following DGL example as a starting point:
https://www.dgl.ai/blog/2019/02/17/gat.html
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3): softmax over each node's incoming edges
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4): attention-weighted sum of neighbour features
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')


class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average over the heads (dim=0 of the stack)
            return torch.mean(torch.stack(head_outs), dim=0)


class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h
I had hoped I could adapt this to simply return edges instead of nodes, e.g. by replacing the line

    return self.g.ndata.pop('h')

with

    return self.g.edata.pop('e')

but it seems things are not that simple. I managed to get something running, but the loss bounced all over the place and no learning took place.
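To make the attempt concrete, here is a minimal sketch of the kind of edge-scoring layer and training loop I have in mind (EdgeScoreLayer, edge_labels, and the hyperparameters are placeholder names of mine, not from the tutorial, and g and features are assumed to be set up as in the tutorial):

    import torch
    import torch.nn as nn

    class EdgeScoreLayer(nn.Module):
        """Embed the nodes, then score each edge of g with a single logit."""
        def __init__(self, g, in_dim, hidden_dim):
            super(EdgeScoreLayer, self).__init__()
            self.g = g
            self.fc = nn.Linear(in_dim, hidden_dim, bias=False)
            # scores the concatenated (src, dst) embeddings, like attn_fc above
            self.score_fc = nn.Linear(2 * hidden_dim, 1, bias=False)

        def edge_score(self, edges):
            z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
            return {'score': self.score_fc(z2)}

        def forward(self, h):
            self.g.ndata['z'] = self.fc(h)
            self.g.apply_edges(self.edge_score)
            return self.g.edata.pop('score').squeeze(1)  # one logit per edge

    # training sketch: edge_labels is a float tensor with one 0/1 entry
    # per edge of g, marking the "true edges"
    net = EdgeScoreLayer(g, in_dim=features.shape[1], hidden_dim=16)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    loss_fn = nn.BCEWithLogitsLoss()
    for epoch in range(100):
        logits = net(features)
        loss = loss_fn(logits, edge_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()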
I am new to graph networks, though not to deep learning in general. Is what I am trying to do reasonable? Am I missing something crucial to understanding how this works? I have not been able to find any accessible examples of graph networks where the edges themselves are the learning target, so I am a bit lost at the moment. I would appreciate any help anyone can offer!