在标准 PyTorch 文档中,它表示torch.nn
将输入视为小批量。因此,对于一个样本,它建议使用input.unsqueeze(0)
以添加假批次。PyTorch 几何nn
模块是这种情况吗?
更具体地说,我想将具有 35 个顶点和标量边权重的全连接图提供给NNConv
图层。因此,我将此图表示为Data
对象,其中Data.x
35x35 邻接矩阵Data.edge_index
是 2 x 1225 张量,因为它是完全连接的,并且Data.edge_attr
是形状 1225 x 1 的张量,因为它是完全连接的,边缘属性只是标量权重。我设计了这样一个NNConv
层,我输入的不是小批量而是一个样本到网络。
nn = Sequential(Linear(1, 1225), ReLU())
self.conv1 = NNConv(35, 35, nn, aggr='mean', root_weight=True, bias=True)
在前向功能
def forward(self, data):
x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
x = F.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr)))
我不明白的是我需要添加假的小批量。这是正确的还是我需要添加x.unsqueeze(0)
,如果是,这些Data
属性中的哪一个 ( x
, edge_index
, edge_attr
) 确实需要unsqueeze(0)
. 谢谢。