我正在研究 GNN,并编写代码:
Pytorch Geometric教程中的Pytorch Geometric介绍代码
import torch_geometric
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root="tutorial1",name= "Cora")
data = dataset[0]
print(data)
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
##############(I omitted my neural network and train(), which are not related to my question)########
def test():
model.eval()
logits, accs = model(), []
for _, mask in data('train_mask', 'val_mask', 'test_mask'):
pred = logits[mask].max(1)[1]
acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
accs.append(acc)
return accs
我很好奇的是
for_, mask in data('train_mask', 'val_mask', 'test_mask):
因为我不明白 data('train_mask', 'val_mask', 'test_mask) 是什么。结果是
<generator object Data.__call__ at 0x7f617c8498d0>
所以我不明白它是什么。我阅读了一些生成器的文档,但是我怎样才能看到这些元素是什么?