0

我试图了解以下实现的 GNN 模型有什么问题PyTorch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv = SAGEConv(dataset.num_features,
                             dataset.num_classes,
                             aggr="max") # max, mean, add ...)
    def forward():
        x = self.conv(data.x, data.edge_index)
        return F.log_softmax(x, dim=1)

但是在尝试运行训练循环时出现以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-34-f3ee8050af6a> in <module>
      1 best_val_acc = test_acc = 0
      2 for epoch in range(1,100):
----> 3     train()
      4     _, val_acc, tmp_test_acc = test()
      5     if val_acc > best_val_acc:

<ipython-input-14-64df4e2a24f9> in train()
      2     model.train()
      3     optimizer.zero_grad()
----> 4     F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
      5     optimizer.step()
      6 

~\anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() takes 0 positional arguments but 1 was given

我正在根据要求添加有关如何调用模型的更多详细信息:

def train():
    model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()

def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs
4

1 回答 1

1

函数torch.nn.Module.forward至少应该有一个参数:self. 在您的情况下,您有两个:self和您的 input data

    def forward(self, data): # <-
        x = self.conv(data.x, data.edge_index)
        return F.log_softmax(x, dim=1)
于 2021-10-07T09:09:29.970 回答