0

我有以下问题:我正在尝试在 dgl 中编写消息传递函数。一轮消息传递后,每个节点应包含以下形式的张量:

[
    [predecessor_node_id, edge_id connecting predecessor and this node, predecessor_node_id], 
    [predecessor_node_id, edge_id connecting predecessor and this node, predecessor_node_id],
    ...
]

在我的消息传递的减少阶段,我不断遇到有关向量维度的问题: RuntimeError:张量的大小必须匹配,除了维度 0。在维度 1 中得到 1 和 3(违规索引为 1) 请启发我关于我的错误行为。

最小的例子:

import dgl
import torch
import numpy as np

def initial_send(edges):
    helper = torch.stack((edges._eid, g.find_edges(edges._eid)[0], edges._eid))
    return {"send_message": helper.T}

def initial_reduce(nodes):
    return {"recieved message": nodes.mailbox["send_message"]}

if __name__ == '__main__':
    src = np.array([0, 1, 2, 2, 3, 3, 4, 3])
    dst = np.array([1, 2, 3, 1, 1, 2, 2, 2])
    g = dgl.graph((src, dst))
    g.update_all(initial_send,initial_reduce)
4

0 回答 0