写点什么

TextBrewer: 融合并改进了 NLP 和 CV 中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用

  • 2023-08-07
    浙江
  • 本文字数:4195 字

    阅读完需:约 14 分钟

TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用

TextBrewer:融合并改进了 NLP 和 CV 中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用

TextBrewer 是一个基于 PyTorch 的、为实现 NLP 中的知识蒸馏任务而设计的工具包,融合并改进了 NLP 和 CV 中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架,用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。

1.简介

TextBrewer 为 NLP 中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。


主要特点:


  • 模型无关:适用于多种模型结构(主要面向 Transfomer 结构)

  • 方便灵活:可自由组合多种蒸馏方法;可方便增加自定义损失等模块

  • 非侵入式:无需对教师与学生模型本身结构进行修改

  • 支持典型的 NLP 任务:文本分类、阅读理解、序列标注等


TextBrewer 目前支持的知识蒸馏技术有:


  • 软标签与硬标签混合训练

  • 动态损失权重调整与蒸馏温度调整

  • 多种蒸馏损失函数: hidden states MSE, attention-based loss, neuron selectivity transfer, ...

  • 任意构建中间层特征匹配方案

  • 多教师知识蒸馏

  • ...


TextBrewer 的主要功能与模块分为 3 块:


  1. Distillers:进行蒸馏的核心部件,不同的 distiller 提供不同的蒸馏模式。目前包含 GeneralDistiller, MultiTeacherDistiller, MultiTaskDistiller 等

  2. Configurations and Presets:训练与蒸馏方法的配置,并提供预定义的蒸馏策略以及多种知识蒸馏损失函数

  3. Utilities:模型参数分析显示等辅助工具


用户需要准备:


  1. 已训练好的教师模型, 待蒸馏的学生模型

  2. 训练数据与必要的实验配置, 即可开始蒸馏


在多个典型 NLP 任务上,TextBrewer 都能取得较好的压缩效果。相关实验见蒸馏效果

2.TextBrewer 结构

2.1 安装要求

  • Python >= 3.6

  • PyTorch >= 1.1.0

  • TensorboardX or Tensorboard

  • NumPy

  • tqdm

  • Transformers >= 2.0 (可选, Transformer 相关示例需要用到)

  • Apex == 0.1.0 (可选,用于混合精度训练)

  • 从 PyPI 自动下载安装包安装:


pip install textbrewer
复制代码


  • 从源码文件夹安装:


git clone https://github.com/airaria/TextBrewer.gitpip install ./textbrewer
复制代码

2.2 工作流程



  • Stage 1 : 蒸馏之前的准备工作:

  • 训练教师模型

  • 定义与初始化学生模型(随机初始化,或载入预训练权重)

  • 构造蒸馏用数据集的 dataloader,训练学生模型用的 optimizer 和 learning rate scheduler

  • Stage 2 : 使用 TextBrewer 蒸馏:

  • 构造训练配置(TrainingConfig)和蒸馏配置(DistillationConfig),初始化 distiller

  • 定义 adaptorcallback ,分别用于适配模型输入输出和训练过程中的回调

  • 调用 distiller train 方法开始蒸馏

2.3 以蒸馏 BERT-base 到 3 层 BERT 为例展示 TextBrewer 用法

在开始蒸馏之前准备:


  • 训练好的教师模型teacher_model (BERT-base),待训练学生模型student_model (3-layer BERT)

  • 数据集dataloader,优化器optimizer,学习率调节器类或者构造函数scheduler_class 和构造用的参数字典 scheduler_args


使用 TextBrewer 蒸馏:


import textbrewerfrom textbrewer import GeneralDistillerfrom textbrewer import TrainingConfig, DistillationConfig
#展示模型参数量的统计print("\nteacher_model's parametrers:")result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)print (result)
print("student_model's parametrers:")result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)print (result)
#定义adaptor用于解释模型的输出def simple_adaptor(batch, model_outputs): # model输出的第二、三个元素分别是logits和hidden states return {'logits': model_outputs[1], 'hidden': model_outputs[2]}
#蒸馏与训练配置# 匹配教师和学生的embedding层;同时匹配教师的第8层和学生的第2层distill_config = DistillationConfig( intermediate_matches=[ {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}, {'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])train_config = TrainingConfig()
#初始化distillerdistiller = GeneralDistiller( train_config=train_config, distill_config = distill_config, model_T = teacher_model, model_S = student_model, adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)
#开始蒸馏with distiller: distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)
复制代码

2.4 蒸馏任务示例

2.4.1 蒸馏效果

我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。



我们测试了不同的学生模型,为了与已有公开结果相比较,除了 BiGRU 都是和 BERT 一样的多层 Transformer 结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了 embedding 层,但不包括最终适配各个任务的输出层。


  • 英文模型



  • 中文模型



2.4.2 蒸馏配置

distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches)#其他参数为默认值
复制代码


不同的模型用的matches我们采用了以下配置:



各种 matches 的定义在examples/matches/matches.py中。均使用 GeneralDistiller 进行蒸馏。

2.4.3 训练配置

蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练 30~60 轮。

2.4.4 英文实验结果

在英文实验中,我们使用了如下三个典型数据集。



我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。


Public results:



Our results:



说明:


  1. 公开模型的名称后括号内是其等价的模型结构

  2. 蒸馏到 T4-tiny 的实验中,SQuAD 任务上使用了 NewsQA 作为增强数据;CoNLL-2003 上使用了 HotpotQA 的篇章作为增强数据

  3. 蒸馏到 T12-nano 的实验中,CoNLL-2003 上使用了 HotpotQA 的篇章作为增强数据

2.4.5 中文实验结果

在中文实验中,我们使用了如下典型数据集。



实验结果如下表所示。




说明:


  1. 以 RoBERTa-wwm-ext 为教师模型蒸馏 CMRC 2018 和 DRCD 时,不采用学习率衰减

  2. CMRC 2018 和 DRCD 两个任务上蒸馏时他们互作为增强数据

  3. Electra-base 的教师模型训练设置参考自Chinese-ELECTRA

  4. Electra-small 学生模型采用预训练权重初始化

3.核心概念

3.1Configurations

  • TrainingConfigDistillationConfig:训练和蒸馏相关的配置。

3.2Distillers

Distiller 负责执行实际的蒸馏过程。目前实现了以下的 distillers:


  • BasicDistiller: 提供单模型单任务蒸馏方式。可用作测试或简单实验。

  • GeneralDistiller (常用): 提供单模型单任务蒸馏方式,并且支持中间层特征匹配,一般情况下推荐使用

  • MultiTeacherDistiller: 多教师蒸馏。将多个(同任务)教师模型蒸馏到一个学生模型上。暂不支持中间层特征匹配

  • MultiTaskDistiller:多任务蒸馏。将多个(不同任务)单任务教师模型蒸馏到一个多任务学生模型。

  • BasicTrainer:用于单个模型的有监督训练,而非蒸馏。可用于训练教师模型

3.3 用户定义函数

蒸馏实验中,有两个组件需要由用户提供,分别是 callbackadaptor :

3.3.1Callback

回调函数。在每个 checkpoint,保存模型后会被distiller调用,并传入当前模型。可以借由回调函数在每个 checkpoint 评测模型效果。

3.3.2Adaptor

将模型的输入和输出转换为指定的格式,向distiller解释模型的输入和输出,以便distiller根据不同的策略进行不同的计算。在每个训练步,batch和模型的输出model_outputs会作为参数传递给adaptoradaptor负责重新组织这些数据,返回一个字典。


更多细节可参见完整文档中的说明。

4.FAQ

Q: 学生模型该如何初始化?


A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从 BERT-base 模型蒸馏到 3 层 BERT 时,可以预先载入RBT3模型权重(中文任务)或 BERT 的前三层权重(英文任务),然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。


Q: 如何设置蒸馏的训练参数以达到一个较好的效果?


A: 知识蒸馏的比有标签数据上的训练需要更多的训练轮数与更大的学习率。比如,BERT-base 上训练 SQuAD 一般以 lr=3e-5 训练 3 轮左右即可达到较好的效果;而蒸馏时需要以 lr=1e-4 训练 30~50 轮。当然具体到各个任务上肯定还有区别,我们的建议仅是基于我们的经验得出的,仅供参考


Q: 我的教师模型和学生模型的输入不同(比如词表不同导致 input_ids 不兼容),该如何进行蒸馏?


A: 需要分别为教师模型和学生模型提供不同的 batch,参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。


Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?


A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。



更多优质内容分享请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

发布于: 刚刚阅读数: 6
用户头像

本博客将不定期更新关于NLP等领域相关知识 2022-01-06 加入

本博客将不定期更新关于机器学习、强化学习、数据挖掘以及NLP等领域相关知识,以及分享自己学习到的知识技能,感谢大家关注!

评论

发布
暂无评论
TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用_人工智能_汀丶人工智能_InfoQ写作社区