写点什么

PGL 图学习之项目实践 (UniMP 算法实现论文节点分类、新冠疫苗项目实战,助力疫情)[系列九]

作者:汀丶
  • 2022-11-28
    浙江
  • 本文字数:9894 字

    阅读完需:约 32 分钟

原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5100049?contributionType=1

1.图学习技术与应用

图是一个复杂世界的通用语言,社交网络中人与人之间的连接、蛋白质分子、推荐系统中用户与物品之间的连接等等,都可以使用图来表达。图神经网络将神经网络运用至图结构中,可以被描述成消息传递的范式。百度开发了 PGL2.2,基于底层深度学习框架 paddle,给用户暴露了编程接口来实现图网络。与此同时,百度也使用了前沿的图神经网络技术针对一些应用进行模型算法的落地。本次将介绍百度的 PGL 图学习技术与应用。

1.1 图来源与建模

首先和大家分享下图学习主流的图神经网络建模方式。



14 年左右开始,学术界出现了一些基于图谱分解的技术,通过频域变换,将图变换至频域进行处理,再将处理结果变换回空域来得到图上节点的表示。后来,空域卷积借鉴了图像的二维卷积,并逐渐取代了频域图学习方法。图结构上的卷积是对节点邻居的聚合。



基于空间的图神经网络主要需要考虑两个问题:


  • 怎样表达节点特征;

  • 怎样表达一整张图。


第一个问题可以使用邻居聚合的方法,第二问题使用节点聚合来解决。



目前大部分主流的图神经网络都可以描述成消息传递的形式。需要考虑节点如何将消息发送至目标节点,然后目标节点如何对收到的节点特征进行接收。

1.2 PGL2.2 回顾介绍

PGL2.2 基于消息传递的思路构建整体框架。PGL 最底层是飞浆核心 paddle 深度学习框架。在此之上,搭建了 CPU 图引擎和 GPU 上进行 tensor 化的图引擎,来方便对图进行如图切分、图存储、图采样、图游走的算法。再上一层,会对用户暴露一些编程接口,包括底层的消息传递接口和图网络实现接口,以及高层的同构图、异构图的编程接口。框架顶层会支持几大类图模型,包括传统图表示学习中的图游走模型、消息传递类模型、知识嵌入类模型等,去支撑下游的应用场景。



最初的 PGL 是基于 paddle1.x 的版本进行开发的,所以那时候还是像 tensorflow 一样的静态图模式。目前 paddle2.0 已经进行了全面动态化,那么 PGL 也相应地做了动态图的升级。现在去定义一个图神经网络就只需要定义节点数量、边数量以及节点特征,然后将图 tensor 化即可。可以自定义如何将消息进行发送以及目标节点如何接收消息。



上图是使用 PGL 构建一个 GAT 网络的例子。最开始会去计算节点的权重,在发送消息的时候 GAT 会将原节点和目标节点特征进行求和,再加上一个非线性激活函数。在接收的时候,可以通过 reduce_softmax 对边上的权重进行归一化,再乘上 hidden state 进行加权求和。这样就可以很方便地实现一个 GAT 网络。



对于图神经网络来讲,在构建完网络后,要对它进行训练。训练方式和一般机器学习有所不同,需要根据图的规模选择适用的训练方案。



例如在小图,即图规模小于 GPU 显存的情况下,会使用 full batch 模式进行训练。它其实就是把一整张图的所有节点都放置在 GPU 上,通过一个图网络来输出所有点的特征。它的好处在于可以跑一个很深的图。这一训练方案会被应用于中小型数据集,例如 Cora、Pubmed、Citeseer、ogbn-arxiv 等。最近在 ICML 上发现了可以堆叠至 1000 层的图神经网络,同样也是在这种中小型数据集上做评估。



对于中等规模的图,即图规模大于 GPU 单卡显存,知识可以进行分片训练,每一次将一张子图塞入 GPU 上。PGL 提供了另一个方案,使用分片技术来降低显存使用的峰值。例如对一个复杂图进行计算时,它的计算复杂度取决于边计算时显存使用的峰值,此时如果有多块 GPU 就可以把边计算进行分块,每台机器只负责一小部分的计算,这样就可以大大地减少图神经网络的计算峰值,从而达到更深的图神经网络的训练。分块训练完毕后,需要通过 NCCL 来同步节点特征。



在 PGL 中,只需要一行 DistGPUGraph 命令就可以在原来 full batch 的训练代码中加入这样一个新特性,使得可以在多 GPU 中运行一个深层图神经网络。例如在 obgn-arxiv 中尝试了比较复杂的 TransformerConv 网络,如果使用单卡训练一个三层网络,其 GPU 显存会被占用近 30G,而使用分片训练就可以将它的显存峰值降低。同时,还实现了并行的计算加速,例如原来跑 100 epoch 需要十分钟,现在只需要 200 秒。



在大图的情况下,又回归到平时做数据并行的 mini batch 模式。Mini batch 与 full batch 相比最主要的问题在于它需要做邻居的采样,而邻居数目的提升会对模型的深度进行限制。这一模式适用于一些巨型数据集,包括 ogbn-products 和 ogbn-papers100m。



发现 PyG 的作者的新工作 GNNAutoScale 能够把一个图神经网络进行自动的深度扩展。它的主要思路是利用 CPU 的缓存技术,将邻居节点的特征缓存至 CPU 内存中。当训练图网络时,可以不用实时获取所有邻居的最新表达,而是获取它的历史 embedding 进行邻居聚合计算。实验发现这样做的效果还是不错的。



在工业界的情况下可能会存在更大的图规模的场景,那么这时候可能单 CPU 也存不下如此图规模的数据,这时需要一个分布式的多机存储和采样。PGL 有一套分布式的图引擎接口,使得可以轻松地在 MPI 以及 K8S 集群上通过 PGL launch 接口进行一键的分布式图引擎部署。目前也支持不同类型的邻居采样、节点遍历和图游走算法。



整体的大规模训练方式包括一个大规模分布式图引擎,中间会包含一些图采样的算子和神经网络的开发算子。顶层针对工业界大规模场景,往往需要一个 parameter server 来存储上亿级别的稀疏特征。借助 paddlefleet 的大规模参数服务器来支持超大规模的 embedding 存储。

1.3 图神经网络技术

1.3.1 节点分类任务


在算法上也进行了一些研究。图神经网络与一般机器学习场景有很大的区别。一般的机器学习假设数据之间独立同分布,但是在图网络的场景下,样本是有关联的。预测样本和训练样本有时会存在边关系。通常称这样的任务为半监督节点分类问题。



解决节点分类问题的传统方法是 LPA 标签传播算法,考虑链接关系以及标签之间的关系。另外一类方法是以 GCN 为代表的特征传播算法,只考虑特征与链接的关系。



通过实验发现在很多数据集下,训练集很难通过过拟合达到 99%的分类准确率。也就是说,训练集中的特征其实包含很大的噪声,使得网络缺乏过拟合能力。所以,想要显示地将训练 label 加入模型,因为标签可以消减大部分歧义。在训练过程中,为了避免标签泄露,提出了 UniMP 算法,把标签传播和特征传播融合起来。这一方法在三个 open graph benchmark 数据集上取得了 SOTA 的结果。



后续还把 UniMP 应用到更大规模的 KDDCup 21 的比赛中,将 UniMP 同构算法做了异构图的拓展,使其在异构图场景下进行分类任务。具体地,在节点邻居采样、批归一化和注意力机制中考虑节点之间的关系类型。

1.3.2 链接预测任务


第二个比较经典的任务是链接预测任务。目前很多人尝试使用 GNN 与 link prediction 进行融合,但是这存在两个瓶颈。首先,GNN 的深度和邻居采样的数量有关;其次,当训练像知识图谱的任务时,每一轮训练都需要遍历训练集的三元组,此时训练的复杂度和邻居节点数量存在线性关系,这就导致了如果邻居比较多,训练一个 epoch 的耗时很长。



借鉴了最近基于纯特征传播的算法,如 SGC 等图神经网络的简化方式,提出了基于关系的 embedding 传播。发现单独使用 embedding 进行特征传播在知识图谱上是行不通的。因为知识图谱上存在复杂的边关系。所以,根据不同关系下 embedding 设计了不同的 score function 进行特征传播。此外,发现之前有一篇论文提出了 OTE 的算法,在图神经网络上进行了两阶段的训练。



使用 OGBL-WikiKG2 数据集训练 OTE 模型需要超过 100 个小时,而如果切换到的特征传播算法,即先跑一次 OTE 算法,再进行 REP 特征传播,只需要 1.7 个小时就可以使模型收敛。所以 REP 带来了近 50 倍的训练效率的提升。还发现只需要正确设定 score function,大部分知识图谱算法使用的特征传播算法都会有效果上的提升;不同的算法使用 REP 也可以加速它们的收敛。



将这一套方法应用到 KDDCup 21 Wiki90M 的比赛中。为了实现比赛中要求的超大规模知识图谱的表示,做了一套大规模的知识表示工具 Graph4KG,最终在 KDDCup 中取得了冠军。

1.4 算法应用落地


PGL 在百度内部已经进行了广泛应用。包括百度搜索中的网页质量评估,会把网页构成一个动态图,并在图上进行图分类的任务。百度搜索还使用 PGL 进行网页反作弊,即对大规模节点进行检测。在文本检索应用中,尝试使用图神经网络与自然语言处理中的语言模型相结合。在其他情况下,的落地场景有推荐系统、风控、百度地图中的流量预测、POI 检索等。



本文以推荐系统为例,介绍一下平时如何将图神经网络在应用中进行落地。


推荐系统常用的算法是基于 item-based 和 user-based 协同过滤算法。Item-based 协同过滤就是推荐和 item 相似的内容,而 user-based 就是推荐相似的用户。这里最重要的是如何去衡量物品与物品之间、用户与用户之间的相似性。



可以将其与图学习结合,使用点击日志来构造图关系(包括社交关系、用户行为、物品关联),然后通过表示学习构造用户物品的向量空间。在这个空间上就可以度量物品之间的相似性,以及用户之间的相似性,进而使用其进行推荐。



常用的方法有传统的矩阵分解方法,和阿里提出的基于随机游走 + Word2Vec 的 EGES 算法。近几年兴起了使用图对比学习来获得节点表示。



在推荐算法中,主要的需求是支持复杂的结构,支持大规模的实现和快速的实验成本。希望有一个工具包可以解决 GNN + 表示学习的问题。所以,对现有的图表示学习算法进行了抽象。具体地,将图表示学习分成了四个部分。第一部分是图的类型,将其分为同构图、异构图、二部图,并在图中定义了多种关系,例如点击关系、关注关系等。第二,实现了不同的样本采样的方法,包括在同构图中常用的 node2Vec 以及异构图中按照用户自定义的 meta path 进行采样。第三部分是节点的表示。可以根据 id 去表示节点,也可以通过图采样使用子图来表示一个节点。还构造了四种 GNN 的聚合方式。



发现不同场景以及不同的图表示的训练方式下,模型效果差异较大。所以的工具还支持大规模稀疏特征 side-info 的支持来进行更丰富的特征组合。用户可能有很多不同的字段,有些字段可能是缺失的,此时只需要通过一个配置表来配置节点包含的特征以及字段即可。还支持 GNN 的异构图自动扩展。你可以自定义边关系,如点击关系、购买关系、关注关系等,并选取合适的聚合方式,如 lightgcn,就可以自动的对 GNN 进行异构图扩展,使 lightgcn 变为 relation-wise 的 lightgcn。



对工具进行了瓶颈分析,发现它主要集中在分布式训练中图采样和负样本构造中。可以通过使用 In-Batch Negative 的方法进行优化,即在 batch 内走负采样,减少通讯开销。这一优化可以使得训练速度提升四至五倍,而且在训练效果上几乎是无损的。此外,在图采样中可以通过对样本重构来降低采样的次数,得到两倍左右的速度提升,且训练效果基本持平。相比于市面上现有的分布式图表示工具,还可以实现单机、双机、四机甚至更多机器的扩展。



不仅如此,还发现游走类模型训练速度较快,比较适合作为优秀的热启动参数。具体地,可以先运行一次 metapath2Vce 算法,将训练得到的 embedding 作为初始化参数送入 GNN 中作为热启动的节点表示。发现这样做在效果上有一定的提升。

1.5 Q&A

Q1:在特征在多卡之间传递的训练模式中,使用 push 和 pull 的方式通讯时间占比大概有多大?


A:通讯时间的占比挺大的。如果是特别简单的模型,如 GCN 等,那么使用这种方法训练,通讯时间甚至会比直接跑这个模型的训练时间还要久。所以这一方法适合复杂模型,即模型计算较多,且通讯中特征传递的数据量相比来说较小,这种情况下就比较适合这种分布式计算。


Q2:图学习中节点邻居数较多会不会导致特征过平滑?


A:这里采用的方法很多时候都很暴力,即直接使用 attention 加多头的机制,这样会极大地减缓过平滑问题。因为使用 attention 机制会使得少量特征被 softmax 激活;多头的方式可以使得每个头学到的激活特征不一样。所以这样做一定比直接使用 GCN 进行聚合会好。


Q3:百度有没有使用图学习在自然语言处理领域的成功经验?


A:之前有类似的工作,你可以关注 ERINESage 这篇论文。它主要是将图网络和预训练语言模型进行结合。也将图神经网络落地到了例如搜索、推荐的场景。因为语言模型本身很难对用户日志中包含的点击关系进行建模,通过图神经网络就可以将点击日志中的后验关系融入语言模型,进而得到较大的提升。


Q4:能详细介绍一下 KDD 比赛中将同构图拓展至异构图的 UniMP 方法吗?


A:首先,每一个关系类型其实应该有不同的邻居采样方法。例如 paper 到 author 的关系,会单独地根据它来采样邻居节点。如果按照同构图的方式来采样,目标节点的邻居节点可能是论文,也可能是作者或者机构,那么采样的节点是不均匀的。其次,在批归一化中按照关系 channel 来进行归一化,因为如果你将 paper 节点和 author 节点同时归一化,由于它们的统计均值和方差不一样,那么这种做法会把两者的统计量同时带骗。同理,在聚合操作中,不同的关系对两个节点的作用不同,需要按照不同关系使用不同的 attention 注意力权重来聚合特征。

2.基于 UniMP 算法实现论文引用网络节点分类任务

图学习之基于 PGL-UniMP 算法的论文引用网络节点分类任务:https://aistudio.baidu.com/aistudio/projectdetail/5116458?contributionType=1


由于文章篇幅问题,为了让学习者有更好的体验,这里新开一个项目完成这个任务。


Epoch 987 Train Acc 0.7554459 Valid Acc 0.7546095Epoch 988 Train Acc 0.7537374 Valid Acc 0.75717235Epoch 989 Train Acc 0.75497127 Valid Acc 0.7573859Epoch 990 Train Acc 0.7611409 Valid Acc 0.75653166Epoch 991 Train Acc 0.75316787 Valid Acc 0.75489426Epoch 992 Train Acc 0.749561 Valid Acc 0.7547519Epoch 993 Train Acc 0.7571544 Valid Acc 0.7551079Epoch 994 Train Acc 0.7516492 Valid Acc 0.75581974Epoch 995 Train Acc 0.7563476 Valid Acc 0.7563181Epoch 996 Train Acc 0.7504627 Valid Acc 0.7538976Epoch 997 Train Acc 0.7476152 Valid Acc 0.75439596Epoch 998 Train Acc 0.7539272 Valid Acc 0.7528298Epoch 999 Train Acc 0.7532153 Valid Acc 0.75396883
复制代码

3.新冠疫苗项目实战,助力疫情

Kaggle 新冠疫苗研发竞赛:https://www.kaggle.com/c/stanford-covid-vaccine/overview



mRNA 疫苗已经成为 2019 冠状病毒最快的候选疫苗,但目前它们面临着关键的潜在限制。目前最大的挑战之一是如何设计超稳定的 RNA 分子(mRNA)。传统疫苗是装在注射器里通过冷藏运输到世界各地,但 mRNA 疫苗目前还不可能做到这一点。


研究人员已经观察到 RNA 分子有降解的倾向。这是一个严重的限制,降解会使 mRNA 疫苗失效。目前,对于特定 RNA 的主干中哪个部位最容易受影响的细节知之甚少。在不了解这些情况的情况下,目前针对 COVID-19 的 mRNA 疫苗必须在高度冷藏条件下准备和运输,它们必须能够得到稳定,否则不太可能送达地球上的每个人。


由斯坦福大学医学院(Stanford’s School of Medicine)计算生物学家瑞朱·达斯(Rhiju Das)教授领导的永恒星系(Eterna)社区将科学家和竞赛玩家聚集在一起,解决谜题并发明药物。Eterna 是一款在线竞赛平台,通过谜题挑战玩家解决诸如 mRNA 设计等科学问题。由斯坦福大学的研究人员合成并进行实验测试,以获得关于 RNA 分子的新见解。Eterna 社区之前已经开启了新的科学原理,对致命疾病做出了新的诊断,并利用世界上最强大的智力资源改善公众生活。Eterna 社区通过其在 20 多份出版物上的贡献推动了生物技术,包括 RNA 生物技术进展。


在这次竞赛中,我们希望利用 Kaggle 社区的数据科学专业知识来开发模型和设计 RNA 降解规则。模型将预测 RNA 分子每个碱基的可能降解率,训练的对象是由超过 3000 个 RNA 分子组成的 Eterna 数据集子集(它们跨越了一整套序列和结构),以及它们在每个位置的降解率。然后,我们将根据 Eterna 玩家刚刚为 COVID-19 mRNA 疫苗设计的第二代 RNA 序列为模型评分。这些最终的测试序列目前正在合成和实验表征在斯坦福大学与建模工作并行——自然将评分模型!


提高 mRNA 疫苗的稳定性已经在探索,我们必须解决这一深刻的科学挑战,以加速 mRNA 疫苗研究,并提供一种针对 COVID-19 背后病毒 SARS-CoV-2 的冰箱稳定疫苗。我们正在试图解决的问题希望得到学术实验室、工业研发团队和超级计算机的帮助,你可以加入电子竞赛玩家、科学家和开发者的团队,在 Eterna 永恒星球上对抗这一毁灭性病毒。

3.1 案例简介

将编码的 DNA 送到细胞中,细胞使用 mRNA(Messenger RNA)组装蛋白,免疫系统检测到组装蛋白质以后,利用构建病毒蛋白的编码基因激活免疫系统产生抗体,增强针对冠状病毒的抵御能力。


不同的 mRNA 生成同一个蛋白质,



mRNA 随着时间的流逝及温度的变化发生了降解,



如何找到结构更加稳定的 mRNA?利用图神经网络找到更稳定的 mRNA,颜色越深越稳定.



3.2 新冠疫苗项目拔高实战

数据分布特征


查看当前挂载的数据集目录


# 加载一些需要用到的模块,设置随机数import jsonimport randomimport numpy as npimport pandas as pd
import matplotlib.pyplot as pltimport networkx as nx
from utils.config import prepare_config, make_dirfrom utils.logger import prepare_logger, log_to_filefrom data_parser import GraphParser
seed = 123np.random.seed(seed)random.seed(seed)
复制代码


# https://www.kaggle.com/c/stanford-covid-vaccine/data# 加载训练用的数据df = pd.read_json('../data/data179441/train.json', lines=True)# 查看一下数据集的内容sample = df.loc[0]print(sample)
index 400id id_2a7a4496fsequence GGAAAGCCCGCGGCGCCGGGCGCCGCGGCCGCCCAGGCCGCCCGGC...structure .....(((...)))((((((((((((((((((((.((((....)))...predicted_loop_type EEEEESSSHHHSSSSSSSSSSSSSSSSSSSSSSSISSSSHHHHSSS...signal_to_noise 0SN_filter 0seq_length 107seq_scored 68reactivity_error [146151.225, 146151.225, 146151.225, 146151.22...deg_error_Mg_pH10 [104235.1742, 104235.1742, 104235.1742, 104235...deg_error_pH10 [222620.9531, 222620.9531, 222620.9531, 222620...deg_error_Mg_50C [171525.3217, 171525.3217, 171525.3217, 171525...deg_error_50C [191738.0886, 191738.0886, 191738.0886, 191738...reactivity [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...deg_Mg_pH10 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...deg_pH10 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...deg_Mg_50C [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...deg_50C [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...Name: 0, dtype: object
复制代码


例如 deg_50C、deg_Mg_50C 这样的值全为 0 的行,就是我们需要预测的。


structure 一行,数据中的括号是为了构成边用的。


本案例要预测 RNA 序列不同位置的降解速率,训练数据中提供了多个 ground 值,标签包括以下几项:reactivity, deg_Mg_pH10, and deg_Mg_50


  • reactivity - (1x68 vector 训练集,1x91 测试集) 一个浮点数数组,与 seq_scores 有相同的长度,是前 68 个碱基的反应活性值,按顺序表示,用于确定 RNA 样本可能的二级结构。

  • deg_Mg_pH10 - (训练集 1x68 向量,1x91 测试集)一个浮点数数组,与 seq_scores 有相同的长度,是前 68 个碱基的反应活性值,按顺序表示,用于确定在高 pH (pH 10)下的降解可能性。

  • deg_Mg_50 - (训练集 1x68 向量,1x91 测试集)一个浮点数数组,与 seq_scores 有相同的长度,是前 68 个碱基的反应活性值,按顺序表示,用于确定在高温(50 摄氏度)下的降解可能性。


# 利用GraphParser构造图结构的数据args = prepare_config("./config.yaml", isCreate=False, isSave=False)parser = GraphParser(args) # GraphParser类来自data_parser.pygdata = parser.parse(sample) # GraphParser里最主要的函数就是parse(self, sample)
复制代码


数据格式:


{'nfeat': array([[0., 0., 0., ..., 0., 0., 0.],        [0., 0., 0., ..., 0., 0., 0.],        [0., 1., 0., ..., 0., 0., 0.],        ...,        [1., 0., 0., ..., 0., 0., 0.],        [1., 0., 0., ..., 0., 0., 0.],        [1., 0., 0., ..., 0., 0., 0.]], dtype=float32), 'edges': array([[  0,   1],        [  1,   0],        [  1,   2],        ...,        [142, 105],        [106, 142],        [142, 106]]), 'efeat': array([[ 0.,  0.,  0.,  1.,  1.],        [ 0.,  0.,  0., -1.,  1.],        [ 0.,  0.,  0.,  1.,  1.],        ...,        [ 0.,  1.,  0.,  0.,  0.],        [ 0.,  1.,  0.,  0.,  0.],        [ 0.,  1.,  0.,  0.,  0.]], dtype=float32), 'labels': array([[ 0.    ,  0.    ,  0.    ],        [ 0.    ,  0.    ,  0.    ],        ...,        [ 0.    ,  0.9213,  0.    ],        [ 6.8894,  3.5097,  5.7754],        [ 0.    ,  1.8426,  6.0642],          ...,                [ 0.    ,  0.    ,  0.    ],        [ 0.    ,  0.    ,  0.    ]], dtype=float32), 'mask': array([[ True],        [ True],     ......       [False]])}
复制代码


# 图数据可视化fig = plt.figure(figsize=(24, 12))nx_G = nx.Graph()nx_G.add_nodes_from([i for i in range(len(gdata['nfeat']))])
nx_G.add_edges_from(gdata['edges'])node_color = ['g' for _ in range(sample['seq_length'])] + \['y' for _ in range(len(gdata['nfeat']) - sample['seq_length'])]options = { "node_color": node_color,}pos = nx.spring_layout(nx_G, iterations=400, k=0.2)nx.draw(nx_G, pos, **options)
plt.show()
复制代码



从图中可以看到,绿色节点是碱基,黄色节点是密码子。


结果返回的是 MCRMSE 和 loss
{'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}
复制代码


[DEBUG] 2022-11-25 17:50:42,468 [ trainer.py: 66]: {'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}[DEBUG] 2022-11-25 17:50:42,468 [ trainer.py: 73]: write to tensorboard ../checkpoints/covid19/eval_history/eval[DEBUG] 2022-11-25 17:50:42,469 [ trainer.py: 73]: write to tensorboard ../checkpoints/covid19/eval_history/eval[INFO] 2022-11-25 17:50:42,469 [ trainer.py: 76]: [Eval:eval]:MCRMSE:0.5496758818626404 loss:0.3025484172316889[INFO] 2022-11-25 17:50:42,602 [monitored_executor.py: 606]: ********** Stop Loop ************[DEBUG] 2022-11-25 17:50:42,607 [monitored_executor.py: 199]: saving step 12500 to ../checkpoints/covid19/model_12500



这部分代码实现参考项目:[PGL图学习之基于GNN模型新冠疫苗任务[系列九]](https://aistudio.baidu.com/aistudio/projectdetail/5123296?contributionType=1)
复制代码


# 我们在 layer.py 里定义了一个新的 gnn 模型(my_gnn),消息传递的过程中加入了边的特征(edge_feat)# 然后修改 model.py 里的 GNNModel# 使用修改后的模型,运行 main.py。为节省时间,设置 epochs = 100
# !python main.py --config config.yaml #训练#!python main.py --mode infer #预测
复制代码

4.总结

本项目讲了论文节点分类任务和新冠疫苗任务,并在论文节点分类任务中对代码进行详细讲解。PGL 八九系列的项目耦合性比较大,也花了挺久时间研究希望对大家有帮助。


后续将做一次大的总结偏向业务侧该如何落地以及图算法的归纳,之后会进行不定期更新图相关的算法!


  • easydict 库和 collections 库!

  • 从官方数据处理部分,学习到利用 np 的 vstack 实现自环边以及知道有向边如何添加反向边的数据——这样的一种代码实现边数据转换的方式!

  • 从模型加载部分,学习了多 program 执行的操作,理清了 program 与命名空间之间的联系!

  • 从模型训练部分,强化了执行器执行时,需要传入正确的 program 以及 feed_dict,在 pgl 中可以使用图 Graph 自带的 to_feed 方法返回一个 feed_dict 数据字典作为初始数据,后边再按需添加新数据!

  • 从 model.py 学习了模型的组网,以及 pgl 中 conv 类下的网络模型方法的调用,方便组网!

  • 重点来了:从 build_model.py 学习了模型的参数的加载组合,实现统一的处理和返回统一的算子以及参数!

发布于: 刚刚阅读数: 4
用户头像

汀丶

关注

本博客将不定期更新关于NLP等领域相关知识 2022-01-06 加入

本博客将不定期更新关于机器学习、强化学习、数据挖掘以及NLP等领域相关知识,以及分享自己学习到的知识技能,感谢大家关注!

评论

发布
暂无评论
PGL图学习之项目实践(UniMP算法实现论文节点分类、新冠疫苗项目实战,助力疫情)[系列九]_图神经网络_汀丶_InfoQ写作社区