写点什么

昇腾 AI4S 图机器学习:DGL 图构建接口的 PyG 替换

作者:Splendid2025
  • 2025-06-12
    浙江
  • 本文字数:1104 字

    阅读完需:约 4 分钟

背景介绍

DGL (Deep Graph Learning) 和 PyG (Pytorch Geometric) 是两个主流的图神经网络库,它们在 API 设计和底层实现上有一定差异,在不同场景下,研究人员会使用不同的依赖库,昇腾 NPU 对 PyG 图机器学习库的支持亲和度更高,因此有些时候需要做 DGL 接口的 PyG 替换。



SE3Transformer 在 RFdiffusion 蛋白质设计模型(GitHub - RosettaCommons/RFdiffusion)作为核心组件,负责处理蛋白质结构的几何信息。其架构基于图神经网络,通过 SE(3)等变性实现对三维旋转和平移的不变性特征提取。本系列以 RFDiffusion 模型中的 SE3Transformer 为例,讲解如何将 DGL 中的接口替换为 PyG 实现。在本文中,主要展示图构建结构的替换。



DGL 图构建接口的 PyG 替换(make_full_graph 和 make_topk_graph)

make_full_graph 函数

位置:


  • rfdiffusion/util_module.py


输入:


  • xyz: 蛋白质骨架坐标,形状为 (B, L, 3) 或 (B, L, 3, 3)

  • pair: 成对特征,形状为 (B, L, L, E)

  • idx: 残基索引


输出:


  • G: DGL 图对象

  • edge_feats: 边特征


调用 DGL 函数:


  • dgl.graph: 创建图结构


数学逻辑:


  1. 提取氨基酸相对位置

  2. 构建完全连接图

  3. 设置边特征和节点特征


PyG 实现代码:


def make_full_graph(xyz, pair, idx, top_k=64, kmin=9):    B, L = xyz.shape[:2]    device = xyz.device
# 确保xyz形状正确 if xyz.dim() > 3: xyz_flat = xyz[:,:,1] if xyz.shape[2] == 3 else xyz.reshape(B, L, 3) else: xyz_flat = xyz
# 计算序列分离 sep = idx[:,None,:] - idx[:,:,None] b,i,j = torch.where(sep.abs() > 0)
# 构建PyG图所需的边索引 src = b*L+i tgt = b*L+j
# 创建图对象 G = graph((src, tgt), num_nodes=B*L).to(device)
# 计算相对位置 rel_pos = xyz_flat[b,j,:] - xyz_flat[b,i,:] if rel_pos.dim() > 2 and rel_pos.shape[-1] == 3: rel_pos = rel_pos.reshape(-1, 3) G.edata['rel_pos'] = rel_pos.detach()
# 处理边特征 edge_feats = pair[b,i,j] if edge_feats.dim() == 1: edge_feats = edge_feats.unsqueeze(-1) if edge_feats.dim() == 2: edge_feats = edge_feats.unsqueeze(-1)
# 归一化特征减少实现差异 edge_feats = torch.tanh(edge_feats / 10.0) * 10.0
return G, edge_feats
复制代码

make_topk_graph

位置:


  • rfdiffusion/util_module.py


输入与输出:


  • 与 make_full_graph 类似,但构建 k 近邻图而非完全图


调用 DGL 函数:


  • dgl.graph: 创建图结构


数学逻辑:


  1. 计算氨基酸之间距离

  2. 选择 top-k 最近邻居

  3. 确保每个节点至少有 kmin 个邻居


优化方案:


  • 使用 PyG 的 knn_graph 函数简化实现

  • 利用 PyG 的批处理机制处理多图

用户头像

Splendid2025

关注

还未添加个人签名 2025-01-26 加入

AI4SCI

评论

发布
暂无评论
昇腾AI4S图机器学习:DGL图构建接口的PyG替换_机器学习_Splendid2025_InfoQ写作社区