1

我正在尝试使用将我的 PyG 图转换为 NetworkX 图to_networkx

根据文档,除了 Data 对象之外,我还可以选择将节点和边缘属性作为 str 可迭代对象传递。

以下是节点和边缘属性列表,其值转换为字符串:

Nodes:  ['3.3375725746154785', '2.0086510181427',..., '1.5960148572921753', '3.621992349624634']

Edges:  ['0.9940207804344958', '0.48573804411542043', ..., '0.7245483440145621', '0.24117984598949904']

to_networkx当我只将 Data 对象传递给它时运行良好。但是,当我也传递这些属性列表时,我收到以下错误:

G[u][v][key] = values[key][i]
KeyError: '0.30194718370332896'

我查看了源代码,但无法弄清楚它在做什么。有人可以帮助解释我的属性列表有什么问题以及我需要更改哪些内容才能被接受。

我可以看出,这个错误专门指的是我的边缘属性。如果我删除它们,我会收到以下与节点属性相关的类似错误:

feat_dict.update({key: values[key][i]})
KeyError: '0.0'

我如何构建我的图表并将其传递给to_networkx

n1 = np.repeat(np.array([0,1,2,3,4,5,6]),5)
n2 = np.array([0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4])
cat = np.stack((n1,n2), axis=0)
e = torch.tensor(cat, dtype=torch.long)
edge_index = e.t().clone().detach()
edge_attr = torch.tensor(np.random.rand(35,1))

x = torch.tensor([[0], [0], [0], [0], [0], [1], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous(), edge_attr = edge_attr)

在传递节点和边属性之前,我会进行字符串转换以符合 str 可迭代要求:

networkx_node_values = list(map(str, data.x.t()[0].tolist()))
networkx_edge_values = list(map(str, edge_attr.t()[0].tolist()))
    
networkX_graph = to_networkx(data, node_attrs = networkx_node_values, edge_attrs = networkx_edge_values)
4

1 回答 1

2

您需要将属性名称作为列表传递:

to_networkx(<PyTorchGeometricDataObject>, node_attrs=[<Name of Node Attribute 1>, <Name of Node Attributes 2>, ... ], edge_attr=[<Edge Attribute 1>, ...])

或者在上下文中,根据您给定的最小示例:

import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

n1 = np.repeat(np.array([0,1,2,3,4,5,6]),5)
n2 = np.array([0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4])
cat = np.stack((n1,n2), axis=0)
e = torch.tensor(cat, dtype=torch.long)
edge_index = e.t().clone().detach()
edge_attr = torch.tensor(np.random.rand(35,1))

x = torch.tensor([[0], [0], [0], [0], [0], [1], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous(), edge_attr = edge_attr)
print(data)
# Data(edge_attr=[35, 1], edge_index=[2, 35], x=[7, 1])

networkX_graph = to_networkx(data, node_attrs=["x"], edge_attrs=["edge_attr"])

print(networkX_graph.nodes(data=True))
# [(0, {'x': 0.0}), (1, {'x': 0.0}),...
print(networkX_graph.edges(data=True))
# [(0, 0, {'edge_attr': 0.3412137594357493}), ...
于 2022-02-08T20:21:04.110 回答