DGL(0.8.x) 技术点分析

DGL 为 Amazon 发布的图神经网络开源库(github)。支持 tensorflow, pytorch, mxnet。
如何初始化一个图:
节点 ID 从 0 开始标号
G = dgl.graph((us, vs)) 一系列点和边,us->vs
G.add_nodes(n) 添加 n 个点
G.add_edge(u, v)添加边 u->v
G.add_edges(u[s], v[s])添加边 u[s]->v[s]
节点和边都可以具有特征数据,存储为键值对,键是可选的,值必须是张量
G.ndata['x'] = th.zeros((3, 5)) 对所有节点都设置特征数据,名称为 x
G.ndata['y'] = th.randn(g.num_nodes(), 5) 不同名称的特征数据可以有不同形状
G.nodes[[0, 2]].data['x'] = th.ones((2, 5)) 对节点 0,2 设置特征数据
消息传递范式:
边上计算:计算边信息(涉及消息函数)
消息函数:接受 edges,成员包括 src, dst, data,得到发出的信息
点上计算:汇总边信息,更新点信息(涉及聚合函数,更新函数
聚合函数: 节点有属性 mailbox 访问节点收到的信息,并进行聚合操作(min, max, sum 等)
更新函数: 用聚合函数的结果对原特征进行更新
高级 API
update_all: 接受一个消息函数,一个聚合函数,一个更新函数
高效的消息传递代码
实现细节避免消耗大量内存: 大矩阵乘法分拆
在图的一部分上进行消息传递:用想囊括的节点创建一个子图

顶层提供对不同业务抽象
Backend 层: 实现多后端适配
Platform 层:适配不同架构,支持高效计算
Platform 层:适配不同架构,支持高效计算

c++层提供性能敏感功能
python 层基于 c++能力拓展更多功能,同时算子多态适配不同后端
图网络算子基于 python 层提供的运算实现



评论