图神经网络之预训练大模型结合:ERNIESage 在链接预测任务应用
1.ERNIESage 运行实例介绍(1.8x 版本)
本项目原链接:https://aistudio.baidu.com/aistudio/projectdetail/5097085?contributionType=1
本项目主要是为了直接提供一个可以运行 ERNIESage 模型的环境,
https://github.com/PaddlePaddle/PGL/blob/develop/examples/erniesage/README.md
在很多工业应用中,往往出现如下图所示的一种特殊的图:Text Graph。顾名思义,图的节点属性由文本构成,而边的构建提供了结构信息。如搜索场景下的 Text Graph,节点可由搜索词、网页标题、网页正文来表达,用户反馈和超链信息则可构成边关系。
ERNIESage 由 PGL 团队提出,是 ERNIE SAmple aggreGatE 的简称,该模型可以同时建模文本语义与图结构信息,有效提升 Text Graph 的应用效果。其中 ERNIE 是百度推出的基于知识增强的持续学习语义理解框架。
ERNIESage 是 ERNIE 与 GraphSAGE 碰撞的结果,是 ERNIE SAmple aggreGatE 的简称,它的结构如下图所示,主要思想是通过 ERNIE 作为聚合函数(Aggregators),建模自身节点和邻居节点的语义与结构关系。ERNIESage 对于文本的建模是构建在邻居聚合的阶段,中心节点文本会与所有邻居节点文本进行拼接;然后通过预训练的 ERNIE 模型进行消息汇聚,捕捉中心节点以及邻居节点之间的相互关系;最后使用 ERNIESage 搭配独特的邻居互相看不见的 Attention Mask 和独立的 Position Embedding 体系,就可以轻松构建 TextGraph 中句子之间以及词之间的关系。
使用 ID 特征的 GraphSAGE 只能够建模图的结构信息,而单独的 ERNIE 只能处理文本信息。通过 PGL 搭建的图与文本的桥梁,ERNIESage 能够很简单的把 GraphSAGE 以及 ERNIE 的优点结合一起。以下面 TextGraph 的场景,ERNIESage 的效果能够比单独的 ERNIE 以及 GraphSAGE 模型都要好。
ERNIESage 可以很轻松地在 PGL 中的消息传递范式中进行实现,目前 PGL 在 github 上提供了 3 个版本的 ERNIESage 模型:
ERNIESage v1: ERNIE 作用于 text graph 节点上;
ERNIESage v2: ERNIE 作用在 text graph 的边上;
ERNIESage v3: ERNIE 作用于一阶邻居及起边上;
主要会针对 ERNIESageV1 和 ERNIESageV2 版本进行一个介绍。
1.1 算法实现
可能有同学对于整个项目代码文件都不太了解,因此这里会做一个比较简单的讲解。
核心部分包含:
数据集部分
data.txt - 简单的输入文件,格式为每行 query \t answer,可作简单的运行实例使用。
模型文件和配置部分
ernie_config.json - ERNIE 模型的配置文件。
vocab.txt - ERNIE 模型所使用的词表。
ernie_base_ckpt/ - ERNIE 模型参数。
config/ - ERNIESage 模型的配置文件,包含了三个版本的配置文件。
代码部分
local_run.sh - 入口文件,通过该入口可完成预处理、训练、infer 三个步骤。
preprocessing 文件夹 - 包含 dump_graph.py, tokenization.py。在预处理部分,我们首先需要进行建图,将输入的文件构建成一张图。由于我们所研究的是 Text Graph,因此节点都是文本,我们将文本表示为该节点对应的 node feature(节点特征),处理文本的时候需要进行切字,再映射为对应的 token id。
dataset/ - 该文件夹包含了数据 ready 的代码,以便于我们在训练的时候将训练数据以 batch 的方式读入。
models/ - 包含了 ERNIESage 模型核心代码。
train.py - 模型训练入口文件。
learner.py - 分布式训练代码,通过 train.py 调用。
infer.py - infer 代码,用于 infer 出节点对应的 embedding。
评价部分
build_dev.py - 用于将我们的验证集修改为需要的格式。
mrr.py - 计算 MRR 值。
要在这个项目中运行模型其实很简单,只要运行下方的入口命令就 ok 啦!但是,需要注意的是,由于 ERNIESage 模型比较大,所以如果 AIStudio 中的 CPU 版本运行模型容易出问题。因此,在运行部署环境时,建议选择 GPU 的环境。
另外,如果提示出现了 GPU 空间不足等问题,我们可以通过调小对应 yaml 文件中的 batch_size 来调整,也可以修改 ERNIE 模型的配置文件 ernie_config.json,将 num_hidden_layers 设小一些。在这里,我仅提供了 ERNIESageV2 版本的 gpu 运行过程,如果同学们想运行其他版本的模型,可以根据需要修改下方的命令。
运行完毕后,会产生较多的文件,这里进行简单的解释。
workdir/ - 这个文件夹主要会存储和图相关的数据信息。
output/ - 主要的输出文件夹,包含了以下内容:(1)模型文件,根据 config 文件中的 save_per_step 可调整保存模型的频率,如果设置得比较大则可能训练过程中不会保存模型; (2)last 文件夹,保存了停止训练时的模型参数,在 infer 阶段我们会使用这部分模型参数;(3)part-0 文件,infer 之后的输入文件中所有节点的 Embedding 输出。
为了可以比较清楚地知道 Embedding 的效果,我们直接通过 MRR 简单判断一下 data.txt 计算出来的 Embedding 结果,此处将 data.txt 同时作为训练集和验证集。
1.2 核心模型代码讲解
首先,我们可以通过查看 models/model_factory.py 来判断在本项目有多少种 ERNIESage 模型。
可以看到一共有 ERNIESage 模型一共有 3 个版本,另外我们也提供了基本的 GNN 模型和 ERNIE 模型,感兴趣的同学可以自行查阅。
接下来,我主要会针对 ERNIESageV1 和 ERNIESageV2 这两个版本的模型进行关键部分的讲解,主要的不同其实就是消息传递机制(Message Passing)部分的不同。
1.2.1 ERNIESageV1 关键代码
通过上述代码片段可以看到,关键的消息传递机制代码就是 graphsage_sum 函数,其中 send、recv 部分如下。
通过代码可以看到,ERNIESageV1 版本,其主要是针对节点邻居,直接将当前节点的邻居节点特征求和。再看到 graphsage_sum 函数中,将邻居节点特征进行求和后,得到了 neigh_feature。随后,我们将节点本身的特征 self_feature 和邻居聚合特征 neigh_feature 通过 fc 层后,直接 concat 起来,从而得到了当前 gnn layer 层的 feature 输出。
1.2.2ERNIESageV2 关键代码
ERNIESageV2 的消息传递机制代码主要在 erniesage_v2.py 和 message_passing.py,相对 ERNIESageV1 来说,代码会相对长了一些。
为了使得大家对下面有关 ERNIE 模型的部分能够有所了解,这里先贴出 ERNIE 的主模型框架图。
具体的代码解释可以直接看注释。
2.总结
通过以上两个版本的模型代码简单的讲解,我们可以知道他们的不同点,其实主要就是在消息传递机制的部分有所不同。ERNIESageV1 版本只作用在 text graph 的节点上,在传递消息(Send 阶段)时只考虑了邻居本身的文本信息;而 ERNIESageV2 版本则作用在了边上,在 Send 阶段同时考虑了当前节点和其邻居节点的文本信息,达到更好的交互效果。
评论