在用于创建消息传递网络的 PyTorch 几何教程中,在解释类的作用时,他们在开头有这一段:
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
:定义要使用的聚合方案("add", "mean" or "max")
和消息传递的流向(或者"source_to_target"
或"target_to_source"
)。此外,该node_dim
属性指示沿哪个轴传播。
我不明白这node_dim
是指什么,以及为什么它是-2。我查看了该类的文档MessagePassing
,它在那里说它是传播的轴——这仍然没有真正阐明我们在这里做什么以及为什么默认值为 -2(大概这就是你传播信息的方式在节点级别)。有人可以向我解释一下吗?