写点什么

昇腾 910-PyTorch 实现 图神经网络 GraphSage

  • 2025-05-26
    上海
  • 本文字数:4254 字

    阅读完需:约 14 分钟

基于 Pytorch Gemotric 在昇腾上实现 GraphSage 图神经网络

本实验主要介绍了如何在昇腾上,使用 pytorch 对经典的图神经网络 GraphSage 在论文引用 CiteSeer 数据集上进行分类训练的实战讲解。


内容包括 GraphSage 创新点分析、GraphSage 算法原理、GraphSage 网络架构剖析与 GraphSage 网络模型代码实战分析等等。


本实验的目录结构安排如下所示:


  • GraphSage 创新点分析

  • GraphSage 算法原理

  • GraphSage 网络架构剖析

  • GraphSage 网络用于 CiteSeer 数据集分类实战

GraphSage 创新点分析

  • 本文提出了一种归纳式学习模型,可以得到新点/新图的表征。

  • 模型通过学习一组函数来得到点的表征。之前的随机游走方式则是先随机初始化点的表征,然后通过模型的训练更新点的表征来获取点的表征,这样无法进行归纳式学习。

  • 采样并汇聚点的邻居特征与节点的特征拼接得到点的特征。

  • GraphSAGE 算法在直推式和归纳式学习均达到最优效果。

GraphSage 算法原理

GCN 网络每次学习都需要将整个图送入显存/内存中,资源消耗巨大。另外使用整个图结构进行学习,导致了 GCN 的学习的固化,图中一旦新增节点,整个图的学习都需要重新进行。这两点对于大数据集和项目实际落地来说,是巨大的阻碍。


我们知道,GCN 网络的每一次卷积的过程,每个节点都是只与自己周围的信息节点进行交互,对于单层的 GCN 网络而言,每个节点只能够接触到自己一跳以内的节点信息,两层的话则是两跳,通过一跳邻居间接传播过来。


假设需要图中一个点的卷积的结果,并且只使用了一层 GCN,那么实际上我只需要这个节点和它的全部邻居节点,图中的其他节点是没有意义的;如果使用了两层 GCN,那么我只需要这个节点和它的两跳以内的全部邻居,以此类推。因为更远节点的信息根本无法到达这个节点。


此外,当图中的某些节点是有几百上千个邻居的超级节点,对于这种节点,哪怕只进行一跳邻居采样,仍然会导致计算困难。GraphSage 网络采用抽取出一部分待训练节点和它们的 N 跳内邻居可以完成这些节点的训练,与此同时对于超级节点的情况,提出了采样的思路,即只对目标节点的邻居进行一定数量的采样,然后通过这次被采样出来的节点和目标节点进行计算,从而逼近完全聚合的效果。


GraphSAGE 模型,是一种在图上的通用的归纳式的框架,利用节点特征信息(例如文本属性)来高效地为训练阶段未见节点生成 embedding。该模型学习的不是节点的 embedding 向量,而是学习一种聚合方式,即如何通过从一个节点的局部邻居采样并聚合顶点特征,得到节点最终 embedding 表征。 当学习到适合的聚合函数后,可以迅速应用到未见过的图上,得到未见过的节点 embedding。


因此采样聚合是 GraphSage 网络的两大主要工作,通过随机采样的方式从整张图中抽出一张子图近似替换原始图,然后在该子图上进行聚合计算提取信息特征。


/*** 算法流程解读 ***/


第一个 for 循环针对层数进行遍历,表示进行多少层的 GraphSAGE, 第二个 for 循环用于遍历 Graph 中的所有节点, 针对每个节点, 对邻居进行采样得到邻居节点集合,然后遍历该集合使用对邻居节点信息进行聚合得到 $\mathbf{h}{N(v)}^{k}\mathbf{h}{u}^{k - 1}$进行拼接, 经过非线性变换后赋得到 v 节点在当前 k 层的节点权重值。


GraphSage 网络架构剖析

GraphSAGE 不是试图学习一个图上所有 node 的 embedding,而是学习一个为每个 node 产生 embedding 的映射。 GraphSage 框架中包含两个很重要的操作:Sample 采样和 Aggregate 聚合。这也是其名字 GraphSage(Graph SAmple and aggreGatE)的由来。GraphSAGE 主要分两步:采样、聚合。GraphSAGE 的采样方式是邻居采样,邻居采样的意思是在某个节点的邻居节点中选择几个节点作为原节点的一阶邻居,之后对在新采样的节点的邻居中继续选择节点作为原节点的二阶节点,以此类推。


文中不是对每个顶点都训练一个单独的 embeddding 向量,而是训练了一组 aggregator functions,这些函数学习如何从一个顶点的局部邻居聚合特征信息。每个聚合函数从一个顶点的不同的 hops 或者说不同的搜索深度聚合信息。测试或是推断的时候,使用训练好的系统,通过学习到的聚合函数来对完全未见过的顶点生成 embedding。


上图包含下述三个步骤:


  • 对图中每个顶点邻居顶点进行采样,因为每个节点的度是不一致的,为了计算高效, 为每个节点采样固定数量的邻居

  • 根据聚合函数聚合邻居顶点蕴含的信息

  • 得到图中各顶点的向量表示供下游任务使用

GraphSage 网络用于 CiteSeer 数据集分类实战

导入 torch 相关库,'functional'集成了一些非线性算子函数,例如 Relu 等。


import torchimport torch.nn.functional as F
复制代码


由于 torch_geometric 中集成了单层的 SAGEConv 模块,这里直接进行导入,若有兴趣可以自行实现该类,注意输入与输出对齐即可。此外,数据集用的是 CiteSeer,该数据集也直接集成在 Planetoid 模块中,这里也需要将其 import 进来。


# 导入GraphSAGE层from torch_geometric.nn import SAGEConvfrom torch_geometric.datasets import Planetoid
复制代码


本实验需要跑在 Npu 上,因此将 Npu 相关库导入,'transfer_to_npu'可以使模型快速的迁移到 Npu 上进行运行。


#导入Npu相关库import torch_npufrom torch_npu.contrib import transfer_to_npu
复制代码

CiteSeer 数据集介绍


CiteSeer 是一个学术论文数据集,主要涉及计算机科学领域。它由 NEC 研究院开发,基于自动引文索引(ACI)机制,提供了一种通过引文链接来检索文献的方式。


CiteSeer 数据集由学术论文组成,每篇论文被视为一个节点,引用关系被视为边,总共包含 6 个类别,对应图中 6 种不同颜色的节点。


包含 3327 篇论文也就是 3327 个节点,每个节点有一个 3703 维的二进制特征向量,用来表示论文的内容,其中特征向量采用词袋模型(Bag of Words)表示,即每个特征维度对应一个词汇表中的词,值为 1 表示该词在论文中出现,为 0 表示未出现。在这些论文中共组成 4732 条边表示论文与论文的引用关系。6 个类别分别是 Agents、Artificial、Intelligence、Database、Information Retrieval、Machine Learning 与 Human-Computer Interaction。


加载数据集,root 为下载路径的保存默认保存位置,若下载不下来可手动下载后保存在指定路径即可。


print("===== begin Download Dadasat=====\n")dataset = Planetoid(root='/home/pengyongrong/workspace/data', name='CiteSeer')print("===== Download Dadasat finished=====\n")
print("dataset num_features is: ", dataset.num_features)print("dataset.num_classes is: ", dataset.num_classes)
print("dataset.edge_index is: ", dataset.edge_index)
print("train data is: ", dataset.data)print("dataset0 is: ", dataset[0])
print("train data mask is: ", dataset.train_mask, "num train is: ", (dataset.train_mask ==True).sum().item())print("val data mask is: ",dataset.val_mask, "num val is: ", (dataset.val_mask ==True).sum().item())print("test data mask is: ",dataset.test_mask, "num test is: ", (dataset.test_mask ==True).sum().item())
复制代码


===== begin Download Dadasat=====
===== Download Dadasat finished=====
dataset num_features is: 3703dataset.num_classes is: 6dataset.edge_index is: tensor([[ 628, 158, 486, ..., 2820, 1643, 33], [ 0, 1, 1, ..., 3324, 3325, 3326]])train data is: Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])dataset0 is: Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])train data mask is: tensor([ True, True, True, ..., False, False, False]) num train is: 120val data mask is: tensor([False, False, False, ..., False, False, False]) num val is: 500test data mask is: tensor([False, False, False, ..., True, True, True]) num test is: 1000
复制代码


搭建两层 GraphSAGE 网络,其中 sage1 与 sage2 分别表示第一层与第二层,这里如果有需要可以搭建多层的 GraphSage,注意保持输入链接出大小相互匹配即可。


class GraphSAGE_NET(torch.nn.Module):
def __init__(self, feature, hidden, classes): super(GraphSAGE_NET, self).__init__() self.sage1 = SAGEConv(feature, hidden) self.sage2 = SAGEConv(hidden, classes)
def forward(self, data): x, edge_index = data.x, data.edge_index
x = self.sage1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.sage2(x, edge_index)
return F.log_softmax(x, dim=1)
复制代码


定义设备跑在 Npu 上,这里如果需要替换成 Gpu 或 Cpu,则替换成'cuda'或'cpu'即可。


device = 'npu'
复制代码


定义 GraphSAGE 网络,中间隐藏层节点个数定义为 16,'dataset.num_classes'为先前数据集中总的类别数,这里是 7 类。'to()'的作用是将该加载到指定模型设备上。优化器用的是'optim'中的'Adam'。


model = GraphSAGE_NET(dataset.num_node_features, 16, dataset.num_classes).to(device) data = dataset[0].to(device)optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
复制代码


开始训练模型,指定训练次数 200 次,训练后采用极大似然用作损失函数计算损失,然后进行反向传播更新模型的参数,训练完成后,用验证集中的数据对模型效果进行验证,最后打印模型的准确率为 0.665。


model.train()for epoch in range(200):    optimizer.zero_grad()    out = model(data)    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])    loss.backward()    optimizer.step()

model.eval()_, pred = model(data).max(dim=1)correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())acc = correct / int(data.test_mask.sum())print('GraphSAGE Accuracy: {:.4f}'.format(acc))
复制代码


GraphSAGE Accuracy: 0.6650
复制代码


内存使用情况: 整个训练过程的内存使用情况可以通过"npu-smi info"命令在终端查看,因此本文实验只用到了单个 npu 卡(也就是 chip 0),内存占用约 943M,对内存、精度或性能优化有兴趣的可以自行尝试进行优化。


Reference

[1] Hamilton, William L , R. Ying , and J. Leskovec . "Inductive Representation Learning on Large Graphs." (2017).


用户头像

还未添加个人签名 2024-12-19 加入

还未添加个人简介

评论

发布
暂无评论
昇腾910-PyTorch 实现 图神经网络GraphSage_永荣带你玩转昇腾_InfoQ写作社区