写点什么

在数据增强、蒸馏剪枝下 ERNIE3.0 分类模型性能提升

作者:汀丶
  • 2022-11-02
    浙江
  • 本文字数:6968 字

    阅读完需:约 23 分钟

在数据增强、蒸馏剪枝下 ERNIE3.0 模型性能提升

项目链接:https://aistudio.baidu.com/aistudio/projectdetail/4436131?contributionType=1


以 CBLUE 数据集中医疗搜索检索词意图分类为例:


本项目首先讲解了数据增强和数据蒸馏的方案,并在后面章节进行效果展示,结果预览:



gensim 安装最新版本:pip install gensim


tqdm 安装:pip install tqdm


LAC 安装最新版本:pip install lac




Gensim 库介绍


Gensim 是在做自然语言处理时较为经常用到的一个工具库,主要用来以无监督的方式从原始的非结构化文本当中来学习到文本隐藏层的主题向量表达。


主要包括 TF-IDF,LSA,LDA,word2vec,doc2vec 等多种模型。


Tqdm


是一个快速,可扩展的 Python 进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。目的为了程序显示的美观


中文词法分析-LAC


LAC 是一个联合的词法分析模型,整体性地完成中文分词、词性标注、专名识别任务。LAC 既可以认为是 Lexical Analysis of Chinese 的首字母缩写,也可以认为是 LAC Analyzes Chinese 的递归缩写。


LAC 基于一个堆叠的双向 GRU 结构,在长文本上准确复刻了百度 AI 开放平台上的词法分析算法。效果方面,分词、词性、专名识别的整体准确率 95.5%;单独评估专名识别任务,F 值 87.1%(准确 90.3,召回 85.4%),总体略优于开放平台版本。在效果优化的基础上,LAC 的模型简洁高效,内存开销不到 100M,而速度则比百度 AI 开放平台提高了 57%


LAC 链接:https://www.paddlepaddle.org.cn/modelbasedetail/lac


!pip install --upgrade paddlenlp!pip install gensim!pip install tqdm!pip install lac

2.数据增强方案介绍

数据增强工具提供 4 种增强策略:遮盖、删除、同词性词替换、词向量近义词替换



!unzip ERNIE-.zip -d ./ERNIE#添加 ERNIE 工具包


如果程序报错:可以发现提示有一个.ipynb_checkpoints的文件。但当我去对应的文件夹找时根本看不到这个文件,所以猜测是一个隐藏文件。所以通过终端进入对应的目录:输入cd coco进入对应目录,输入ls -a显示所有文件。然后输入rm -rf .ipynb_checkpoints删除该文件。再次输入ls -a查看文件是否被删除。
复制代码


下载词表,词表有 1.7G 会花点时间。下面以情感分析数据样例展示 demo,看看数据增强的效果。


!wget -q --no-check-certificate http://bj.bcebos.com/wenxin-models/vec2.txt
复制代码


python data_aug.py "输入文件夹的目录" "输出文件夹的目录"


  • data_aug.py 脚本传参说明


shell输入:    python data_aug.py -h
shell输出: usage: data_aug.py [-h] [-n AUG_TIMES] [-c COLUMN_NUMBER] [-u UNK] [-t TRUNCATE] [-r POS_REPLACE] [-w W2V_REPLACE] [-e ERNIE_REPLACE] [--unk_token UNK_TOKEN] input output main positional arguments: input #原始待增强数据文件所在文件夹,带label的,一个或多个文本列 output #输出文件路径 optional arguments: -h, --help show this help message and exit -n AUG_TIMES, --aug_times AUG_TIMES #数据集数目放大n倍,output行数为input的n+1倍 -c COLUMN_NUMBER, --column_number COLUMN_NUMBER #明文文件中所要增强列的列序号,多列用逗号分割,如:1,2 -u UNK, --unk UNK #unk 增强策略的概率 -t TRUNCATE, --truncate TRUNCATE #truncate 增强策略的概率 -r POS_REPLACE, --pos_replace POS_REPLACE #pos_replace 增强策略的概率 -w W2V_REPLACE, --w2v_replace W2V_REPLACE #w2v_replace 增强策略的概率 --unk_token UNK_TOKEN
复制代码


分类问题中:推荐使用前三种即可,w2v 词向量近义词替换可以不用,花费时间太长。


!python data_aug.py --unk 0.25 --truncate 0.25 --pos 0.5 --w2v 0 ./train ./output
复制代码


demo结果展示:
机器 背面 似乎 被 撕 了 张 什么 标签 , 残 胶 还在 。 但是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪 0机器 背面 似乎 被 撕 了 张 什么 标签 , 胶 还在 。 但是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪 0机器 背面 了 张 什么 标签 , 残 胶 还在 。 但是 又 看 不 出 是 什么 标签 了 , 该在 , 怪 0呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。 0呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我😄妈 爱 看 , 我自己 也 学 着 找 一些 穴位 😄 0呵呵 , 虽然 表皮 看上去 不错 很 精致 , 但是 我 还😄 能 看得出来 是 盗😄😄😄。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 ,😄😄😄😄😄😄😄学 着 找 😄😄😄😄😄😄😄 0😄😄😄😄😄虽然 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。 0😄😄😄😄😄😄😄 表皮 看上去 不错 很 精致 , 但是 我 还是 能 看得出来 是 盗 的 。 但是 里面 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴位 。 0地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近 。 1地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近。。 1地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 机器 还算 干净 , 离 湖南路小吃街 近 。 1地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽带 速度 满意 , 房间 还算 干净 , 离 湖南路小吃街 近 。 1地理 位置 佳 , 在 市中心 。 酒店 服务 好 、 早餐 品种 丰富 。 我 住 的 商务 数码 房 电脑 宽我 看 是 书 的 还 可以 , 但是 我 订 的 书 迟迟 还 到 能 半个月 , 都 没有 收到 打电话 也 没
复制代码

2.0 补充 nlpcda 一键中文数据增强工具(NLP Chinese Data Augmentation )

一键中文数据增强工具,支持:


1.随机实体替换 2.近义词 3.近义近音字替换 4.随机字删除(内部细节:数字时间日期片段,内容不会删)5.NER 类 BIO 数据增强 6.随机置换邻近的字:研表究明,汉字序顺并不定一影响文字的阅读理解<<是乱序的 7.中文等价字替换(1 一 壹 ①,2 二 贰 ②)8.翻译互转实现的增强 9.使用 simbert 做生成式相似句生成


参考链接:一键中文数据增强包 ; NLP数据增强、bert数据增强、EDA:pip install nlpcdanlpcda一键中文数据增强工具

3.数据蒸馏技术

ERNIE 数据蒸馏三步

Step 1. 使用 ERNIE 模型对输入标注数据对进行 fine-tune,得到 Teacher Model


Step 2. 使用 ERNIE Service 对以下无监督数据进行预测:


  • 用户提供的大规模无标注数据,需与标注数据同源

  • 对标注数据进行数据增强,具体增强策略

  • 对无标注数据和数据增强数据进行一定比例混合


Step 3. 使用步骤 2 的数据训练出 Student Model


数据增强


目前采用三种数据增强策略策略,对于不用的任务可以特定的比例混合。三种数据增强策略包括:


添加噪声:对原始样本中的词,以一定的概率(如 0.1)替换为”UNK”标签


同词性词替换:对原始样本中的所有词,以一定的概率(如 0.1)替换为本数据集钟随机一个同词性的词


N-sampling:从原始样本中,随机选取位置截取长度为 m 的片段作为新的样本,其中片段的长度 m 为 0 到原始样本长度之间的随机值



模型剪裁,基于 PaddleNLP 的 Trainer API 发布提供了模型裁剪 API。裁剪 API 支持用户对 ERNIE 等 Transformers 类下游任务微调模型进行裁剪。


具体效果在下一节展现,先安装好 paddleslim 库

4.基于 ERNIR3.0 文本模型微调

加载已有数据集:CBLUE 数据集中医疗搜索检索词意图分类(训练)


数据集定义:以公开数据集 CBLUE 数据集中医疗搜索检索词意图分类(KUAKE-QIC)任务为示例,在训练集上进行模型微调,并在开发集上使用准确率 Accuracy 评估模型表现。


数据集默认为:默认为"cblue"。


save_dir:保存训练模型的目录;默认保存在当前目录 checkpoint 文件夹下。


dataset:训练数据集;默认为"cblue"。


<font color="red">dataset_dir:本地数据集路径,数据集路径中应包含 train.txt,dev.txt 和 label.txt 文件;默认为 None。</font>


task_name:训练数据集;默认为"KUAKE-QIC"。


max_seq_length:ERNIE 模型使用的最大序列长度,最大不能超过 512, 若出现显存不足,请适当调低这一参数;默认为 128。


<font color="red">model_name:选择预训练模型;默认为"ernie-3.0-base-zh"。</font>


<font color="red">device: 选用什么设备进行训练,可选 cpu、gpu、xpu、npu。如使用 gpu 训练,可使用参数 gpus 指定 GPU 卡号。</font>


batch_size:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为 32。


learning_rate:Fine-tune 的最大学习率;默认为 6e-5。


weight_decay:控制正则项力度的参数,用于防止过拟合,默认为 0.01。


early_stop:选择是否使用早停法(EarlyStopping);默认为 False。


<font color="red">early_stop_nums:在设定的早停训练轮次内,模型在开发集上表现不再上升,训练终止;默认为 4。epochs: 训练轮次,默认为 100。</font>


warmup:是否使用学习率 warmup 策略;默认为 False。


warmup_proportion:学习率 warmup 策略的比例数,如果设为 0.1,则学习率会在前 10%steps 数从 0 慢慢增长到 learning_rate, 而后再缓慢衰减;默认为 0.1。


logging_steps: 日志打印的间隔 steps 数,默认 5。


init_from_ckpt: 模型初始 checkpoint 参数地址,默认 None。


seed:随机种子,默认为 3。


#修改后的训练文件train_new2.py ,主要使用了paddlenlp.metrics.glue的AccuracyAndF1:准确率及F1-score,可用于GLUE中的MRPC 和QQP任务#不过吐槽一下:    return (acc,precision,recall,f1,(acc + f1) / 2,) 最后一个指标竟然是加权平均.....!python train_new2.py --warmup --early_stop --epochs 10 --save_dir "./checkpoint2" --batch_size 16 --model_name ernie-3.0-base-zh
复制代码


训练结果部分展示:


[2022-08-16 19:58:36,834] [    INFO] - global step 1280, epoch: 3, batch: 412, loss: 0.23292, acc: 0.87106, speed: 16.54 step/s[2022-08-16 19:58:37,392] [    INFO] - global step 1290, epoch: 3, batch: 422, loss: 0.22339, acc: 0.87130, speed: 17.94 step/s[2022-08-16 19:58:37,960] [    INFO] - global step 1300, epoch: 3, batch: 432, loss: 0.22791, acc: 0.87182, speed: 17.68 step/s(acc, precision, recall, f1, average_of_acc_and_f1):(0.8025575447570332, 0.9317147192716236, 0.908284023668639, 0.9198501872659175, 0.8612038660114754)
复制代码


[2022-08-16 20:01:36,060] [ INFO] - Early stop![2022-08-16 20:01:36,060] [ INFO] - Save best accuracy text classification model in ./checkpoint2

4.1 加载自定义数据集(并通过数据增强训练)

从本地文件创建数据集


使用本地数据集来训练我们的文本分类模型,本项目支持使用固定格式本地数据集文件进行训练如果需要对本地数据集进行数据标注,可以参考文本分类任务 doccano 数据标注使用指南进行文本分类数据标注。[这个放到下个项目讲解]


本项目将以 CBLUE 数据集中医疗搜索检索词意图分类(KUAKE-QIC)任务为例进行介绍如何加载本地固定格式数据集进行训练:


本地数据集目录结构如下:


data/├── train.txt # 训练数据集文件├── dev.txt # 开发数据集文件├── label.txt # 分类标签文件└── data.txt # 可选,待预测数据文件
复制代码


部分结果展示


[2022-08-16 23:43:18,093] [    INFO] - global step 2400, epoch: 2, batch: 234, loss: 0.60859, acc: 0.84437, speed: 19.27 step/s(acc, precision, recall, f1, average_of_acc_and_f1):(0.7979539641943734, 0.9010043041606887, 0.9289940828402367, 0.9147851420247632, 0.8563695531095683)[2022-08-16 23:43:24,522] [    INFO] - Save best F1 text classification model in ./checkpoint3[2022-08-16 23:43:24,523] [    INFO] - best F1 performence has been updated: 0.91450 --> 0.91479
复制代码

4.2 数据蒸馏

!unset CUDA_VISIBLE_DEVICES!python -m paddle.distributed.launch --gpus "0" prune.py \    --device "gpu" \    --output_dir "./prune" \    --per_device_train_batch_size 32 \    --per_device_eval_batch_size 32 \    --learning_rate 3e-5 \    --num_train_epochs 5 \    --logging_steps 10 \    --save_steps 50 \    --seed 3 \    --dataset_dir "KUAKE_QIC" \    --max_seq_length 128 \    --params_dir "./checkpoint3" \    --width_mult '0.5'
复制代码


部分结果展示:


[2022-08-17 14:22:30,954] [    INFO] - width_mult: 0.5, eval loss: 0.63535, acc: 0.79847(acc, precision, recall, f1, average_of_acc_and_f1):(0.7984654731457801, 0.9512578616352201, 0.8949704142011834, 0.9222560975609755, 0.8603607853533778)[2022-08-17 14:22:35,870] [    INFO] - Save best F1 text classification model in ./prune/0.5[2022-08-17 14:22:35,870] [    INFO] - best F1 performence has been updated: 0.92226 --> 0.92226
复制代码


!unset CUDA_VISIBLE_DEVICES!python -m paddle.distributed.launch --gpus "0" prune.py \    --device "gpu" \    --output_dir "./prune" \    --per_device_train_batch_size 32 \    --per_device_eval_batch_size 32 \    --learning_rate 3e-5 \    --num_train_epochs 5 \    --logging_steps 10 \    --save_steps 50 \    --seed 3 \    --dataset_dir "KUAKE_QIC" \    --max_seq_length 128 \    --params_dir "./checkpoint3" \    --width_mult '2/3'
复制代码


2022-08-17 14:53:45,544] [    INFO] - global step 3070, epoch: 2, batch: 904, loss: 0.709566, speed: 9.93 step/s[2022-08-17 14:53:46,550] [    INFO] - global step 3080, epoch: 2, batch: 914, loss: 0.607238, speed: 9.94 step/s[2022-08-17 14:53:47,558] [    INFO] - global step 3090, epoch: 2, batch: 924, loss: 0.718484, speed: 9.93 step/s[2022-08-17 14:53:48,563] [    INFO] - global step 3100, epoch: 2, batch: 934, loss: 0.546288, speed: 9.95 step/s[2022-08-17 14:53:50,206] [    INFO] - teacher model, eval loss: 0.66438, acc: 0.80358[2022-08-17 14:53:50,207] [    INFO] - eval done total : 1.6434180736541748 s[2022-08-17 14:53:53,568] [    INFO] - width_mult: 0.6666666666666666, eval loss: 0.60219, acc: 0.80921(acc, precision, recall, f1, average_of_acc_and_f1):(0.8092071611253197, 0.9415384615384615, 0.9053254437869822, 0.923076923076923, 0.8661420421011213)[2022-08-17 14:53:58,489] [    INFO] - Save best F1 text classification model in ./prune/0.6666666666666666[2022-08-17 14:53:58,489] [    INFO] - best F1 performence has been updated: 0.92308 --> 0.92308
复制代码

4.3 模型预测

输入待预测数据和数据标签对照列表,模型预测数据对应的标签


使用默认数据进行预测:


#也可以选择使用本地数据文件data/data.txt进行预测:!python predict.py --params_path ./checkpoint3/ --dataset_dir ./KUAKE_QIC --device "cpu"
复制代码


黑苦荞茶的功效与作用及食用方法 功效作用交界痣会凸起吗 疾病表述检查是否能怀孕挂什么科 就医建议鱼油怎么吃咬破吃还是直接咽下去 其他幼儿挑食的生理原因是 病因分析
复制代码


!python predict.py \    --device "cpu" \    --dataset_dir ./KUAKE_QIC \    --params_path "./prune/0.5" \
复制代码

5.总结

本项目首先讲解了数据增强和数据蒸馏的方案,并在后面章节进行效果展示,现在进行汇总



分析可得,


  • 首先数据增强后导致性能部分下降部分和预期的原因:随机 mask、删除会产生过多噪声样本影响结果,推荐只使用同义词替换,本次样本数据量足够,且 ERNIE 性能本就优越,数据增强对结果提升在较大样本集可以忽略。

  • 其次,可以看到通过数据蒸馏后,模型性能变化不大,甚至在剪裁 1/3 之后,性能有小幅度提升


本次主要对分类模型加入数据增强、数据蒸馏,已经对性能指标进行细化,不只是 ACC,个人比较关注 F1 情况,并作为保存模型依据。


展望: 后续将完善动态图和静态图转化部分,让蒸馏下来模型可以继续线上加载使用;其次将会考虑小样本学习在分类模型应用情况;最后将完成模型融合环节提升性能,并做可解释性分析。


本人博客:https://blog.csdn.net/sinat_39620217?type=blog

用户头像

汀丶

关注

还未添加个人签名 2022-01-06 加入

还未添加个人简介

评论

发布
暂无评论
在数据增强、蒸馏剪枝下ERNIE3.0分类模型性能提升_nlp_汀丶_InfoQ写作社区