如何将 pytorch 几何 torch_geometric.data.DataLoader batch_data 的节点特征重新格式化为 (batch_size, num_nodes_per_graph, feature_dim) 形状?
torch_geometric.data.DataLoader 生成具有 (total_num_nodes, feature_dim) 形状的 batch_data.x(节点特征)的每个批次。但我想在 model.forward() 函数中以(batch_size,num_nodes_per_each_sample_graph,feature_dim)的形式重新格式化它。有没有办法做到这一点?
例如,重新格式化
[1605, 512] --> [64, 40, 512]
这里的节点总数是 1605。但是每个样本并不完全有 40 个节点,所以我们必须用零填充它们。但是每个节点特征都有 512 维。