Geneformer:基于 Transformer 的基因表达预测深度学习模型
摘要
Geneformer 被广泛应用于疾病建模、治疗靶点发掘、基因网络预测与调控分析、基因功能预测与剂量敏感性分析、单细胞转录组数据集成与标准化、遗传变异解释与 GWAS 靶点优先排序。该案例既有算法原理,也有手把手的昇腾部署教学,包含细胞分类、基因分类、提取细胞嵌入图、细胞多分类的微调任务
1 Geneformer 介绍
GeneFormer 是一种基于 Transformer 架构的深度学习模型,专为基因表达数据分析而设计。它将基因视为“词汇”,将整个基因组的表达谱视为“句子”,通过自监督学习捕捉基因间的复杂调控关系和生物学背景,在医学研究中展现出强大的应用潜力。借助 GeneFormer,研究人员能够更有效地处理和理解大量的基因组数据,从而加速新药开发、疾病治疗等领域的研究进展。在基因序列分析、蛋白质结构预测疾病机制解析和药物发现等领域也具有突出的应用价值。

图 1:自监督大规模预训练迁移学习策略示意图
初始自监督大规模预训练,将预训练权重复制到每个微调任务的模型中,添加微调层,并使用有限的特定任务数据对每个下游任务进行微调。通过对可泛化的学习目标进行单次初始自监督大规模预训练,模型获得学习领域的基础知识,然后将其推广到众多不同于预训练学习目标的下游应用中,将知识迁移到新任务。
2 网络结构

图 2:Geneformer 预训练架构图
预训练的 Geneformer 架构。每个单细胞转录组被编码为秩值编码,然后经过六层 Transformer 编码器单元,参数如下:输入大小为 2,048(完全代表 Geneformer-30M 中 93% 的秩值编码),嵌入维度为 256,每层四个注意力头,前馈大小为 512。Geneformer 在 2,048 的输入大小上使用完全密集的自注意力机制。可提取的输出包括上下文基因和细胞嵌入、上下文注意力权重以及上下文预测。
2.1 输入层
输入层针对基因表达数据的特性在数据预处理、嵌入表示(Embedding)和位置编码(Positional Encoding)等进行了专门优化。
数据预处理:
基因嵌入:对基因表达值进行归一化处理,消除不同基因表达水平之间的差异,并对缺失值进行合理填充或插值处理,以确保数据的完整性。
输入数据:通常包括基因表达矩阵(如单细胞 RNA 测序数据)和基因序列(如 DNA 序列)。基因表达矩阵是一个二维矩阵,其中行代表样本,列代表基因,每个元素代表对应基因在该样本中的表达值。基因序列则是由碱基 A、T、C、G 组成的字符串序列。
嵌入层:将基因表达值或基因序列映射到高维向量空间,以捕捉基因间的复杂关系,便于后续模型处理序列结构。维度设置需要根据具体任务和计算资源进行权衡,过低的维度可能导致信息丢失,而过高会增加计算复杂度。此外,嵌入层通常通过反向传播进行训练,使模型能够自动学习最优的基因嵌入表示,从而更好地适应任务需求。
位置编码:用于提供基因序列中各碱基的位置信息,帮助模型理解基因序列中碱基的顺序关系和位置依赖性,对于分析基因序列的功能和结构至关重要。
2.2Transformer 层
GeneFormer 的核心由多个 Transformer 层堆叠而来。通过多头自注意力、残差连接和前馈神经网络,从高维基因表达数据中提取复杂的调控模式。在保持标准的 Transformer 结构的同时,针对基因表达数据的特性(高维度、稀疏性、基因共表达模式)进行了优化,使模型能够有效捕捉基因间的功能关联,为下游任务(微调)提供强有力的表征。
多头注意力:并行使用多个注意力头,每个头学习不同的交互模式,同时计算多组注意力权重,捕捉基因间的全局依赖关系(如协同表达的基因网络)。通过计算查询(Query)、键(Key)和值(Value)之间的点积来确定权重,并通过 Softmax 函数进行归一化,且总和为 1。


将输入拆分为 h 个头,每个头单独计算后拼接。
前馈神经网络:由两层全连接层和激活函数组成,每个多头注意力层后接一个前馈神经网络层,对注意力层的输出进行非线性映射增强非线性表达能力,用于学习并保存知识。

稀疏注意力:基因表达数据中,大部分基因表达值为 0,可能采用局部稀疏注意力以降低计算开销。

相对位置编码:由于基因在序列中的物理位置可能无关紧要,Geneformer 采用相对位置编码,仅编码基因间的相对顺序或距离,增强对基因序列位置的敏感性。

i, j 为基因在序列中的位置,k 为最大相对距离。
层归一化与残差连接:层归一化稳定单细胞数据的高变异表达分布,残差连接保留原始基因表达信息,缓解梯度消失,加速收敛。

μ和σ分别为样本内均值和方差,γ和β分别为可学习的缩放和平移参数

2.3 输出层
经过 transformer 层之后,张量被传入输出层,但 Geneformer 输出层的设计根据具体任务(如基因表达预测、分类或自监督预训练)有所不同,主要操作通常包括以下几个关键步骤:
线性变换:使用全连接层,将 Transformer 最后一层输出的隐藏状态映射到目标维度(如基因数量或类别数)。

是 transformer 最后一层的输出向量, 是基因表达预测的权重矩阵, 为偏置项。
激活函数:根据任务需求不同调整使用的激活函数,回归任务可能使用 ReLU 或 Softplus 确保输出非负,分类任务使用 Softmax(多分类)或 Sigmoid(二分类)输出概率分布,对于线性输出,则没有激活函数。

损失计算:对于回归任务,使用均方误差(MSE)或负对数似然。分类任务,交叉熵损失。自监督任务(掩码基因预测),使用对比损失或遮蔽语言建模(MLM)类似的损失。

是被遮蔽基因的秩编码值(如基因表达量在细胞内的排序分位数)
细胞分类任务中的损失计算,交叉熵损失

基因扰动预测,对比损失

其中, 是基因敲除后的细胞表达谱,通过对比学习强化扰动前后的表达差异。
3 微调介绍
GeneFormer 先在大规模单细胞数据上预训练,结合特定任务的需求和数据特点,灵活选择冻结策略、调整输出头、引入适配器或领域特定模块。通过平衡预训练知识的保留与任务适配,高效实现模型优化。
网络结构的微调操作:
根据具体的下游任务,确定输入输出格式。即指定数据集。在输入层将数据预处理为与 GeneFormer 兼容的格式,加载预训练的 GeneFormer 权重。
选择冻结一定数目的 transformer 层,但不会全部冻结,会保留几层用于保留预训练模型的底层知识(如基因共现模式、 基础序列特征),防止小数据过拟合。
在预训练模型的基础上额外增加一个 transformer 层,用于学习新的知识。并在每一层插入小型适配器模块,保持预训练权重冻结,仅训练适配器参数,用于减少参数更新量,适用于小样本微调。
在输出层,也会根据具体的下游任务进行调整,仅训练最后一层 transformer 层及输出头。对于分类任务:替换最后的全局平均池化层 + 全连接层。回归任务:调整输出层为线性回归头。生成任务:添加解码器。
4 实验准备
4.1 设备 &组件
机器:
Atlas 800T A2
组件:
hdk:24.1.rc3

添加图片注释,不超过 140 字(可选)
cann:8.0.RC3

添加图片注释,不超过 140 字(可选)
python:3.10.16

添加图片注释,不超过 140 字(可选)
torch:2.1.0
torch:2.1.0.post8

添加图片注释,不超过 140 字(可选)
4.2 安装 LFS
4.3 下载源码
4.4 下载数据集

添加图片注释,不超过 140 字(可选)
4.5 安装环境
requirements.txt 里面 torch 的版本>=2.0.1 即可,这里选用 2.1.0 版本的 torch。
4.6 安装 torch-npu
4.6.1 下载
4.6.2 安装
4.6.3 验证 npu 是否可用
numpy 报错,需降低至 1.x 版本

添加图片注释,不超过 140 字(可选)

添加图片注释,不超过 140 字(可选)
更换完 numpy 版本之后,再次验证

添加图片注释,不超过 140 字(可选)
5 微调
5.1 微调 1:细胞分类
5.1.1 数据集 &权重任务
任务:cell_classification
数据集:human_dcm_hcm_nf.dataset
预训练权重:gf-6L-30M-i2048
5.1.2 新建微调脚本
将 cell_classification.ipynb 的代码复制过来。需注意修改权重路径和数据集路径。导入 os 包,将第 10 行的!mkdir $output_dir 修改为 os.makedirs(output_dir, exist_ok=True)

5.1.3 修改评估模型脚本
导入 torch_npu 包,替换相关 cuda 的 api

5.1.4 微调前 source cann
5.1.5 开始微调

再开一个窗口,命令行输入 npu-smi info 查看显存占用率


5.1.6 评估模型时报错


再重新运行

输出精度 0.9542330129066371

输出文件

混淆矩阵

评估微调模型的预测结果

5.2 微调 2:基因分类
5.2.1 数据集 &权重文件
任务:gene_classification
数据集:gc-30M_sample50k.dataset
预训练权重:gf-6L-30M-i2048
5.2.2 新建微调脚本
将 gene_classification.ipynb 的代码复制过来。需注意修改权重路径和数据集路径

5.2.3 开始微调


5.2.4 输出文件

5.3 微调 3:绘制细胞嵌入图
5.3.1 数据集 &权重文件
任务:extract_and_plot_cell_embeddings
数据集:human_dcm_hcm_nf.dataset
预训练权重:gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224
5.3.2 新建微调脚本
将 extract_and_plot_cell_embeddings.ipynb 的代码复制过来。需注意修改权重路径和数据集路径

5.3.3 开始微调

5.3.4 输出文件

5.3.5 细胞嵌入 UMAP 图

5.3.6 细胞嵌入 heapmap 图

5.4 微调 4:多任务细胞分类
5.4.1 数据集 &权重文件
任务:multitask_cell_classification
数据集:human_dcm_hcm_nf.dataset
预训练权重:gf-6L-30M-i2048
5.4.2 新建微调脚本
将 multitask_cell_classification.ipynb 的代码复制过来。需注意修改权重路径、数据集路径以及 token_dictionary 路径。

5.4.3 微调过程

5.4.4 输出

6 参考文献
Theodoris, C. V., Xiao, L., Chopra, A., Chaffin, M. D., Al Sayed, Z. R., Hill, M. C., ... & Ellinor, P. T. (2023). Transfer learning enables predictions in network biology. Nature, 618(7965), 616-624. https://doi.org/10.1038/s41586-023-06139-9
评论