写点什么

昇腾 AI4S 图机器学习:DGL 消息传递接口的 PyG 替换

作者:Splendid2025
  • 2025-06-13
    浙江
  • 本文字数:1013 字

    阅读完需:约 3 分钟

背景介绍

DGL (Deep Graph Learning) 和 PyG (Pytorch Geometric) 是两个主流的图神经网络库,它们在 API 设计和底层实现上有一定差异,在不同场景下,研究人员会使用不同的依赖库,昇腾 NPU 对 PyG 图机器学习库的支持亲和度更高,因此有些时候需要做 DGL 接口的 PyG 替换。


SE3Transformer 在 RFdiffusion 蛋白质设计模型中(GitHub - RosettaCommons/RFdiffusion: Code for running RFdiffusion)作为核心组件,负责处理蛋白质结构的几何信息。其架构基于图神经网络,通过 SE(3)等变性实现对三维旋转和平移的不变性特征提取。本系列以 RFDiffusion 模型中的 SE3Transformer 为例,讲解如何将 DGL 中的接口替换为 PyG 实现。


在本文中,主要展示消息传递接口的 PyG 替换。

消息传递接口

一、边-节点消息传递 (EdgeSoftmax + Aggregation)

位置:

rfdiffusion/modules/equivariant_attention/modules.py 中的 TransformerLayer

输入:

  • 节点特征: x , 形状为(N, F)

  • 边特征: edge_attr , 形状为(E, F')

  • 图结构: graph

输出:

  • 更新的节点特征: 形状为(N, F_out)

DGL 函数:

  • dgl.nn.EdgeSoftmax:对边特征进行归一化

  • dgl.function.copy_edge:复制边特征

  • dgl.function.sum:聚合消息

数学逻辑:

1. 计算注意力分数:

2. 消息聚合:

PyG 实现:

def edge_softmax_aggregation(x, edge_index, edge_attr): 	# 计算源节点和目标节点索引	src, dst = edge_index
# 计算边softmax exp_edge_attr = torch.exp(edge_attr)
# 按目标节点归一化 node_degree = scatter_add(exp_edge_attr, dst, dim=0, dim_size=x.size(0)) norm = node_degree[dst].clamp(min=1e-6) norm_edge_attr = exp_edge_attr / norm
# 消息传递 message = norm_edge_attr * x[src]
# 聚合 out = scatter_add(message, dst, dim=0, dim_size=x.size(0))
return out
复制代码

二、矢量特征消息传递

位置:

rfdiffusion/modules/equivariant_attention/modules.py 中的 AttentionBlockSE3

输入:

  • 标量特征: feat_scalar , 形状为(N, F_s)

  • 矢量特征: feat_vector , 形状为(N, F_v, 3)

  • 图结构: graph

输出:

  • 更新的标量和矢量特征

DGL 函数:

  • dgl.nn.EdgeSoftmax:边特征 softmax

  • g.send_and_recv:消息传递与聚合

数学逻辑:


  1. 矢量特征旋转: ,其中 是相对方向

PyG 实现关键点:

  • 需要自定义消息传递函数

  • 实现等变性旋转操作

  • 处理批处理边索引

用户头像

Splendid2025

关注

还未添加个人签名 2025-01-26 加入

AI4SCI

评论

发布
暂无评论
昇腾AI4S图机器学习:DGL消息传递接口的PyG替换_Splendid2025_InfoQ写作社区