0

我正在研究 Pytorch-Geometric 文档(此处)。

在下面的代码中,我们看到data没有train_mask. 但是,当将输出和标签传递给损失函数时,train_mask两者都适用。在将其输入模型时,我们不应该也应用train_maskto吗?data在我看来,这应该不是问题。然而,看起来我们在不用于训练模型的输出上浪费了计算。

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
4

2 回答 2

1

我认为主要的答案是它使代码更加复杂。

你不能简单地说out = model(data[data.train_mask])。(这将引发异常,因为Dataobject 不支持以这种方式切片)。

您可以拥有单独的对象train_data test_data等,但这会使代码更加复杂。

于 2021-09-30T11:33:28.073 回答
1

我认为在示例中仅计算所有节点的输出的主要原因Pytorch Geometric与另一个答案中提出的“无数据切片问题”不同。train_mask您需要比包含更多节点的隐藏表示(由图卷积派生) 。因此,您不能简单地只给出这些节点的特征(分别是数据)。但是一些优化是可能的,我将在最后讨论。

我假设您设置的是节点分类(如示例代码和问题中的链接)。

例子

让我们使用一个小玩具示例,其中包含五个节点和以下边:

A<->B
B<->C
C<->D
D<->E

并假设您使用仅将节点A作为训练的 2 层 GNN。要计算 GNN 的输出A,您需要 的第一个隐藏表示B,它使用 的输入特征C。因此,您需要 2 跳邻域A来计算其输出。

可能的优化

如果您有多个训练节点(通常有)并且您有一个 k 层 GNN,它通常(并不总是将稀释的 GNN 视为示例)在 k-hop 邻域上运行。然后,您可以通过为每个训练节点组合 k-hop 邻域来计算连接的节点集。由于这取决于模型并且需要一些代码,我猜它没有包含在“示例介绍”中。无论如何,您可能只会看到对较大图表的影响,而对像 Cora 这样的图表的影响可以忽略不计。

于 2022-01-05T14:26:58.433 回答