0

如何将 pytorch 几何 torch_geometric.data.DataLoader batch_data 的节点特征重新格式化为 (batch_size, num_nodes_per_graph, feature_dim) 形状?

torch_geometric.data.DataLoader 生成具有 (total_num_nodes, feature_dim) 形状的 batch_data.x(节点特征)的每个批次。但我想在 model.forward() 函数中以(batch_size,num_nodes_per_each_sample_graph,feature_dim)的形式重新格式化它。有没有办法做到这一点?

例如,重新格式化

[1605, 512] --> [64, 40, 512]

这里的节点总数是 1605。但是每个样本并不完全有 40 个节点,所以我们必须用零填充它们。但是每个节点特征都有 512 维。

4

1 回答 1

0

我在 pytorch 几何文档中找到了答案。(使用 torch_geometric.utils.to_dense_batch )这是 API 的链接 - https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.to_dense_batch

于 2021-02-19T22:49:25.887 回答