我正在尝试在其中使用图形注意网络(GAT)模块,torch_geometric
但继续AssertionError: Static graphs not supported in 'GATConv'
使用以下代码。
class GraphConv_sum(nn.Module):
def __init__(self, in_ch, out_ch, num_layers, block, adj):
super(GraphConv_sum, self).__init__()
adj_coo = coo_matrix(adj) # convert the adjacency matrix to COO format for Pytorch Geometric
self.edge_index = torch.tensor([adj_coo.row, adj_coo.col], dtype=torch.long)
self.g_conv = nn.ModuleList()
self.act = nn.LeakyReLU()
for n in range(num_layers):
if n == 0:
self.g_conv.append(block(in_ch, 16))
elif n > 0 and n < num_layers - 1:
self.g_conv.append(block(16, 16))
else:
self.g_conv.append(block(16, out_ch))
def forward(self, x):
for layer in self.g_conv:
x = layer(x=x, edge_index=self.edge_index)
x = self.act(x)
print(x.shape)
return x[:, 0, :]
当我替换block
为GATConv
后跟标准训练循环时,会发生此错误(其他卷积层,例如GCNConv
或SAGEConv
没有任何问题)。我检查了文档并确保输入形状正确(其他卷积层相同)。
在源代码assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
中,该方法中有这一部分,forward
但显然批处理维度将在前向传递中发挥作用,并且x.dim()
将为 3。批处理维度的输入形状为 [1024,6,200]。但是,如果我手动将断言条件更改x.dim() == 3
为相同的错误,仍然会引发好像条件不满足一样。我只对 GAT 有较高的了解,所以我可能遗漏了一些东西。无论如何,我对此有几个问题
- 我这边是否存在任何可能导致此错误的实施错误?
- 这个断言条件是干什么用的?在这种情况下,静态图是什么?
我将不胜感激任何见解和帮助!谢谢!