0

我正在尝试在其中使用图形注意网络(GAT)模块,torch_geometric但继续AssertionError: Static graphs not supported in 'GATConv'使用以下代码。

class GraphConv_sum(nn.Module):
    def __init__(self, in_ch, out_ch, num_layers, block, adj):
        super(GraphConv_sum, self).__init__()
        adj_coo = coo_matrix(adj) # convert the adjacency matrix to COO format for Pytorch Geometric
        self.edge_index = torch.tensor([adj_coo.row, adj_coo.col], dtype=torch.long)
        self.g_conv = nn.ModuleList()
        
        self.act = nn.LeakyReLU()

        for n in range(num_layers):
            if n == 0:
                self.g_conv.append(block(in_ch, 16))
            elif n > 0 and n < num_layers - 1:
                self.g_conv.append(block(16, 16))
            else:
                self.g_conv.append(block(16, out_ch))

    def forward(self, x):
        for layer in self.g_conv:
            x = layer(x=x, edge_index=self.edge_index)
            x = self.act(x)
            print(x.shape)
        return x[:, 0, :]

当我替换blockGATConv后跟标准训练循环时,会发生此错误(其他卷积层,例如GCNConvSAGEConv没有任何问题)。我检查了文档并确保输入形状正确(其他卷积层相同)。

在源代码assert x.dim() == 2, "Static graphs not supported in 'GATConv'"中,该方法中有这一部分,forward但显然批处理维度将在前向传递中发挥作用,并且x.dim()将为 3。批处理维度的输入形状为 [1024,6,200]。但是,如果我手动将断言条件更改x.dim() == 3为相同的错误,仍然会引发好像条件不满足一样。我只对 GAT 有较高的了解,所以我可能遗漏了一些东西。无论如何,我对此有几个问题

  • 我这边是否存在任何可能导致此错误的实施错误?
  • 这个断言条件是干什么用的?在这种情况下,静态图是什么?

我将不胜感激任何见解和帮助!谢谢!

4

1 回答 1

0

事实证明,由于注意力权重计算,GATConv 不支持多个特征矩阵和单个 edge_index。更多信息:https ://github.com/pyg-team/pytorch_geometric/issues/2844

于 2022-02-08T18:35:47.033 回答