将 GNN 与 HetreoGraph 一起使用是否正确?
运行代码几次会报错,其他次都正常。
重现
多次运行代码会报错:DGLError: Expected data to have X rows, got Y。
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import dgl
from dgl import DGLGraph
import dgl.nn as dglnn
class RGNN(nn.Module):
def __init__(self, in_size, hid_size, out_size, rel_names):
super().__init__()
self.conv1 = dglnn.HeteroGraphConv({rel : dglnn.GraphConv(in_size, hid_size) \
for rel in rel_names}, aggregate="sum")
self.conv2 = dglnn.HeteroGraphConv({rel : dglnn.GraphConv(hid_size, out_size) \
for rel in rel_names}, aggregate="sum")
self.conv3 = dglnn.HeteroGraphConv({rel : dglnn.GraphConv(out_size, out_size) \
for rel in rel_names}, aggregate="sum")
self.dropout=nn.Dropout(p=0.2)
def forward(self, blocks, in_feats):
print("before conv1", in_feats["user"].shape, in_feats["game"].shape)
h = self.conv1(blocks[0], in_feats)
print("after conv1", h["user"].shape, h["game"].shape)
h = self.conv2(blocks[1], h)
print("after conv2", h["user"].shape, h["game"].shape)
h = self.conv3(blocks[2], h)
print("conv3", h["user"].shape, h["game"].shape)
return h
hetero_frontier = dgl.heterograph({
('user', 'follow', 'user'): ([1, 3, 7,1,2,3,4,5,6,7,8,9], [3, 6, 8,4,5,6,7,3,2,5,8,5]),
('user', 'play', 'game'): ([5, 5, 4] + list(range(10)), [6, 6, 2] + list(np.random.randint(0, 10,(1,10)).squeeze()))},
num_nodes_dict={'user': 10, 'game': 10})
embed_size = 128
hetero_frontier.nodes['user'].data['embedding'] = torch.FloatTensor(np.random.normal(0, 0.01, (10, embed_size)))
hetero_frontier.nodes['game'].data['embedding'] = torch.FloatTensor(np.random.normal(0, 0.01, (10, embed_size)))
model = RGNN(embed_size, embed_size, embed_size, hetero_frontier.etypes)
train_nid_dict = {"user":list(np.random.randint(0,10,(1,5)).squeeze()),"game":np.random.randint(10)}
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
dataloader = dgl.dataloading.NodeDataLoader(hetero_frontier, train_nid_dict, sampler,
batch_size=128,
shuffle=False,
drop_last=False,
num_workers=1)
for input_nodes, output_nodes, blocks in dataloader:
print("input nodes", input_nodes["user"].shape,input_nodes["game"].shape)
for idx, block in enumerate(blocks):
print("block ",idx,block.srcdata["embedding"]["user"].shape,block.srcdata["embedding"]["game"].shape)
print("ouput nodes", output_nodes["user"].shape,output_nodes["game"].shape)
input_features = blocks[0].srcdata["embedding"] # returns a dict
output_features = model(blocks, input_features)
DGLError: Expected data to have 5 rows, got 4. 但是我遵循指南
预期行为
无尺寸错误
环境
- DGL 版本 0.6.1
- PyTorch 1.8.1
- 苹果系统:
- 你是如何安装 DGL 的:conda
- Python版本:3.7
- CUDA/cuDNN 版本(如果适用):仅 CPU