写点什么

PGL 图学习之基于 GNN 模型新冠疫苗任务 [系列九]

作者:汀丶
  • 2022-11-29
    浙江
  • 本文字数:2837 字

    阅读完需:约 9 分钟

PGL 图学习之基于 GNN 模型新冠疫苗任务[系列九]

项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5123296?contributionType=1


# 加载一些需要用到的模块,设置随机数import jsonimport randomimport numpy as npimport pandas as pd
import matplotlib.pyplot as pltimport networkx as nx
from utils.config import prepare_config, make_dirfrom utils.logger import prepare_logger, log_to_filefrom data_parser import GraphParser
seed = 123np.random.seed(seed)random.seed(seed)
复制代码

数据 EDA

# https://www.kaggle.com/c/stanford-covid-vaccine/data# 加载训练用的数据df = pd.read_json('../data/data179441/train.json', lines=True)# 查看一下数据集的内容sample = df.loc[0]print(sample)
index 400id id_2a7a4496fsequence GGAAAGCCCGCGGCGCCGGGCGCCGCGGCCGCCCAGGCCGCCCGGC...structure .....(((...)))((((((((((((((((((((.((((....)))...predicted_loop_type EEEEESSSHHHSSSSSSSSSSSSSSSSSSSSSSSISSSSHHHHSSS...signal_to_noise 0SN_filter 0seq_length 107seq_scored 68reactivity_error [146151.225, 146151.225, 146151.225, 146151.22...deg_error_Mg_pH10 [104235.1742, 104235.1742, 104235.1742, 104235...deg_error_pH10 [222620.9531, 222620.9531, 222620.9531, 222620...deg_error_Mg_50C [171525.3217, 171525.3217, 171525.3217, 171525...deg_error_50C [191738.0886, 191738.0886, 191738.0886, 191738...reactivity [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...deg_Mg_pH10 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...deg_pH10 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...deg_Mg_50C [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...deg_50C [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...Name: 0, dtype: object
复制代码


例如 deg_50C、deg_Mg_50C 这样的值全为 0 的行,就是我们需要预测的。


structure 一行,数据中的括号是为了构成边用的。


本案例要预测 RNA 序列不同位置的降解速率,训练数据中提供了多个 ground 值,标签包括以下几项:reactivity, deg_Mg_pH10, and deg_Mg_50


reactivity - (1x68 vector 训练集,1x91 测试集) 一个浮点数数组,与 seq_scores 有相同的长度,是前 68 个碱基的反应活性值,按顺序表示,用于确定 RNA 样本可能的二级结构。


deg_Mg_pH10 - (训练集 1x68 向量,1x91 测试集)一个浮点数数组,与 seq_scores 有相同的长度,是前 68 个碱基的反应活性值,按顺序表示,用于确定在高 pH (pH 10)下的降解可能性。


deg_Mg_50 - (训练集 1x68 向量,1x91 测试集)一个浮点数数组,与 seq_scores 有相同的长度,是前 68 个碱基的反应活性值,按顺序表示,用于确定在高温(50 摄氏度)下的降解可能性。


# 利用GraphParser构造图结构的数据args = prepare_config("./config.yaml", isCreate=False, isSave=False)parser = GraphParser(args) # GraphParser类来自data_parser.pygdata = parser.parse(sample) # GraphParser里最主要的函数就是parse(self, sample)
复制代码


{'nfeat': array([[0., 0., 0., ..., 0., 0., 0.],        [0., 0., 0., ..., 0., 0., 0.],        [0., 1., 0., ..., 0., 0., 0.],        ...,        [1., 0., 0., ..., 0., 0., 0.],        [1., 0., 0., ..., 0., 0., 0.],        [1., 0., 0., ..., 0., 0., 0.]], dtype=float32), 'edges': array([[  0,   1],        [  1,   0],        [  1,   2],        ...,        [142, 105],        [106, 142],        [142, 106]]), 'efeat': array([[ 0.,  0.,  0.,  1.,  1.],        [ 0.,  0.,  0., -1.,  1.],        [ 0.,  0.,  0.,  1.,  1.],        ...,        [ 0.,  1.,  0.,  0.,  0.],        [ 0.,  1.,  0.,  0.,  0.],        [ 0.,  1.,  0.,  0.,  0.]], dtype=float32), 'labels': array([[ 0.    ,  0.    ,  0.    ],        [ 0.    ,  0.    ,  0.    ],        ...,        [ 0.    ,  0.9213,  0.    ],        [ 6.8894,  3.5097,  5.7754],        [ 0.    ,  1.8426,  6.0642],          ...,                [ 0.    ,  0.    ,  0.    ],        [ 0.    ,  0.    ,  0.    ]], dtype=float32), 'mask': array([[ True],        [ True],     ......       [False]])}
复制代码


nfeat —— 节点特征


edges —— 边


efeat —— 边特征


labels —— 节点标签有三种,所以这可以看成是一个多分类任务

图数据可视化

# 图数据可视化fig = plt.figure(figsize=(24, 12))nx_G = nx.Graph()nx_G.add_nodes_from([i for i in range(len(gdata['nfeat']))])
nx_G.add_edges_from(gdata['edges'])node_color = ['g' for _ in range(sample['seq_length'])] + \['y' for _ in range(len(gdata['nfeat']) - sample['seq_length'])]options = { "node_color": node_color,}pos = nx.spring_layout(nx_G, iterations=400, k=0.2)nx.draw(nx_G, pos, **options)
plt.show()
复制代码


模型训练 &预测

# 我们在 layer.py 里定义了一个新的 gnn 模型(my_gnn),消息传递的过程中加入了边的特征(edge_feat)# 然后修改 model.py 里的 GNNModel# 使用修改后的模型,运行 main.py。为节省时间,设置 epochs = 100
!python main.py --config config.yaml
复制代码


结果返回的是 MCRMSE 和 loss


{'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}


[DEBUG] 2022-11-25 17:50:42,468 [  trainer.py:   66]:  {'MCRMSE': 0.5496759, 'loss': 0.3025484172316889}[DEBUG] 2022-11-25 17:50:42,468 [  trainer.py:   73]:  write to tensorboard ../checkpoints/covid19/eval_history/eval[DEBUG] 2022-11-25 17:50:42,469 [  trainer.py:   73]:  write to tensorboard ../checkpoints/covid19/eval_history/eval[INFO] 2022-11-25 17:50:42,469 [  trainer.py:   76]:  [Eval:eval]:MCRMSE:0.5496758818626404  loss:0.3025484172316889[INFO] 2022-11-25 17:50:42,602 [monitored_executor.py:  606]:  ********** Stop Loop ************[DEBUG] 2022-11-25 17:50:42,607 [monitored_executor.py:  199]:  saving step 12500 to ../checkpoints/covid19/model_12500
复制代码


!python main.py --mode infer
复制代码


发布于: 刚刚阅读数: 3
用户头像

汀丶

关注

本博客将不定期更新关于NLP等领域相关知识 2022-01-06 加入

本博客将不定期更新关于机器学习、强化学习、数据挖掘以及NLP等领域相关知识,以及分享自己学习到的知识技能,感谢大家关注!

评论

发布
暂无评论
PGL图学习之基于GNN模型新冠疫苗任务[系列九]_图神经网络_汀丶_InfoQ写作社区