昇腾 AI4S 图机器学习:DGL 图构建接口的 PyG 替换
背景介绍
DGL (Deep Graph Learning) 和 PyG (Pytorch Geometric) 是两个主流的图神经网络库,它们在 API 设计和底层实现上有一定差异,在不同场景下,研究人员会使用不同的依赖库,昇腾 NPU 对 PyG 图机器学习库的支持亲和度更高,因此有些时候需要做 DGL 接口的 PyG 替换。

SE3Transformer 在 RFdiffusion 蛋白质设计模型(GitHub - RosettaCommons/RFdiffusion)作为核心组件,负责处理蛋白质结构的几何信息。其架构基于图神经网络,通过 SE(3)等变性实现对三维旋转和平移的不变性特征提取。本系列以 RFDiffusion 模型中的 SE3Transformer 为例,讲解如何将 DGL 中的接口替换为 PyG 实现。在本文中,主要展示图构建结构的替换。

DGL 图构建接口的 PyG 替换(make_full_graph 和 make_topk_graph)
make_full_graph 函数
位置:
rfdiffusion/util_module.py
输入:
xyz: 蛋白质骨架坐标,形状为 (B, L, 3) 或 (B, L, 3, 3)
pair: 成对特征,形状为 (B, L, L, E)
idx: 残基索引
输出:
G: DGL 图对象
edge_feats: 边特征
调用 DGL 函数:
dgl.graph: 创建图结构
数学逻辑:
提取氨基酸相对位置
构建完全连接图
设置边特征和节点特征
PyG 实现代码:
make_topk_graph
位置:
rfdiffusion/util_module.py
输入与输出:
与 make_full_graph 类似,但构建 k 近邻图而非完全图
调用 DGL 函数:
dgl.graph: 创建图结构
数学逻辑:
计算氨基酸之间距离
选择 top-k 最近邻居
确保每个节点至少有 kmin 个邻居
优化方案:
使用 PyG 的 knn_graph 函数简化实现
利用 PyG 的批处理机制处理多图
评论