0

在标准 PyTorch 文档中,它表示torch.nn将输入视为小批量。因此,对于一个样本,它建议使用input.unsqueeze(0)以添加假批次。PyTorch 几何nn模块是这种情况吗?

更具体地说,我想将具有 35 个顶点和标量边权重的全连接图提供给NNConv图层。因此,我将此图表示为Data对象,其中Data.x35x35 邻接矩阵Data.edge_index是 2 x 1225 张量,因为它是完全连接的,并且Data.edge_attr是形状 1225 x 1 的张量,因为它是完全连接的,边缘属性只是标量权重。我设计了这样一个NNConv层,我输入的不是小批量而是一个样本到网络。

nn = Sequential(Linear(1, 1225), ReLU())
self.conv1 = NNConv(35, 35, nn, aggr='mean', root_weight=True, bias=True)

在前向功能

def forward(self, data):
    x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

    x = F.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr)))

我不明白的是我需要添加假的小批量。这是正确的还是我需要添加x.unsqueeze(0),如果是,这些Data属性中的哪一个 ( x, edge_index, edge_attr) 确实需要unsqueeze(0). 谢谢。

4

0 回答 0