0

我需要你的帮助。我有两组图形结构化数据,一组来自 Open Graph Benchmark (OGB),另一组torch_geometric.data.Dataset来自我自己的数据。数据如下:

数据(edge_index=[2, 88], edge_attr=[88, 3], x=[39, 9], y=[1, 1]) #OGB

数据(x=[23, 9], edge_index=[2, 48], edge_attr=[48, 2], y=[1]) #PyG

我正在尝试使用使用 OGB 函数开发的框架,这不适用于使用 PyG 创建的数据。例如:框架的第一部分加载并将数据集拆分为训练、验证和测试:

# Set the random seed
random.seed(random_seed)
np.random.seed(random_seed)

# Create data loaders
split_idx = dataset.get_idx_split() # train/val/test split
loader_dict = {}
for phase in split_idx:
    batch_size = 32
    loader_dict[phase] = DataLoader(dataset[split_idx[phase]], batch_size=batch_size, shuffle=False)

当我使用本机 ogb 数据集运行此代码时,我没有问题,当我使用 PyG 数据时返回错误:

属性错误

这很奇怪,因为它们都是 Pytorch 对象,唯一的区别是 OGB 数据集是 InMemoryDataset 而 PyG 是“更大”数据集(https://pytorch-geometric.readthedocs.io/en/latest/notes/ create_dataset.html)。有什么方法可以解决这个问题而不必更改源代码?

谢谢!

4

1 回答 1

1

如果要使用相同的代码,则需要get_idx_split为自己的数据集实现。您可以在OGB GitHub 中找到所需的返回结构,例如:

def get_idx_split(self):
    < ... do something to retrieve train/test/validation set>
    return {'train': train_idx, 'valid': valid_idx, 'test': test_idx}
于 2022-02-15T12:43:57.413 回答