0

我正在研究 GNN,并编写代码:

Pytorch Geometric教程中的Pytorch Geometric介绍代码

import torch_geometric
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root="tutorial1",name= "Cora")
data = dataset[0]
print(data)
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
##############(I omitted my neural network and train(), which are not related to my question)########

def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

我很好奇的是

  for_, mask in data('train_mask', 'val_mask', 'test_mask):

因为我不明白 data('train_mask', 'val_mask', 'test_mask) 是什么。结果是

<generator object Data.__call__ at 0x7f617c8498d0>

所以我不明白它是什么。我阅读了一些生成器的文档,但是我怎样才能看到这些元素是什么?

4

1 回答 1

0

data您从数据集中检索的对象Planetoid是单个图形。您具有以下属性:

  • x节点特征,因此它的维度是节点数(2703)乘以特征维度(1433)
  • edge_index边缘列表
  • y“基本事实”/类别标签或在特定情况下是论文的分类。因此,它的形状是节点的数量。
  • 三个面具:train_mask, val_mask, test_mask. 如果我通过 访问它们data.train_mask,它会给我一个长度 = 节点数的布尔张量。这是数据集的“默认拆分”。它们应该是不相交的,并且如果True相应的节点在该集合中。
于 2022-01-03T12:27:14.470 回答