MedicalGPT:基于 LLaMA-13B 的中英医疗问答模型(LoRA)
MedicalGPT:基于 LLaMA-13B 的中英医疗问答模型(LoRA)、实现包括二次预训练、有监督微调、奖励建模、强化学习训练[LLM:含 Ziya-LLaMA]。
**** 训练医疗大模型,实现包括二次预训练、有监督微调、奖励建模、强化学习训练。
分四阶段训练 GPT 模型,来自 Andrej Karpathy 的演讲 PDF State of GPT,视频 Video
版本迭代
V1:发布中文医疗 LoRA 模型,基于 Ziya-LLaMA-13B-v1 模型,SFT 微调了一版医疗模型,医疗问答效果有提升,发布微调后的 LoRA 权重,
V0:以医疗为例,训练领域大模型,实现了四阶段训练:包括二次预训练、有监督微调、奖励建模、强化学习训练。
基于 ChatGPT Training Pipeline,本项目实现了领域模型--医疗模型的四阶段训练:
第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文档数据上二次预训练 GPT 模型,以注入领域知识
第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图
第三阶段:RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来对齐人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless"
第四阶段:RL(Reinforcement Learning)基于人类反馈的强化学习(RLHF),用奖励模型来训练 SFT 模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本
1.模型介绍
1.1 基于 LLaMA-13B 的中英医疗问答模型(LoRA)
在中文开放测试集中的表现优异,继承了两方面的优势:
微调训练的底座是 Ziya-LLaMA-13B 模型,是较强的中英文底座模型,
微调使用的是高质量 240 万条中英文医疗指令数据集,和多种通用指令数据集,微调后的模型在医疗行业答复能力达到领先水平,在通用问题上的答复能力不弱于 LLaMA-13B。
1.1.1 训练评估结果
training args:
train loss:
evaluate loss:
本项目开源在 github repo:
使用 textgen 库:textgen,可调用 LLaMA 模型:
Install package:
1.1.2 HuggingFace Transformers
Without textgen, you can use the model like this:
First, you pass your input through the transformer model, then you get the generated sentence.
Install package:
output:
模型文件组成:
1.1.3 预测结果
1.1.4 训练数据集
50 万条中文 ChatGPT 指令 Belle 数据集:BelleGroup/train_0.5M_CN
100 万条中文 ChatGPT 指令 Belle 数据集:BelleGroup/train_1M_CN
5 万条英文 ChatGPT 指令 Alpaca 数据集:50k English Stanford Alpaca dataset
2 万条中文 ChatGPT 指令 Alpaca 数据集:shibing624/alpaca-zh
69 万条中文指令 Guanaco 数据集 (Belle50 万条 + Guanaco19 万条):Chinese-Vicuna/guanaco_belle_merge_v1.0
240 万条中文医疗数据集 (包括预训练数据和指令微调数据集):shibing624/medical
如果需要训练 ChatGLM/LLAMA/BLOOM 模型,请参考 https://github.com/shibing624/textgen
1.2 姜子牙系列模型
1.2.1 简介
姜子牙通用大模型 V1 是基于 LLaMa 的 130 亿参数的大规模预训练模型,具备翻译,编程,文本分类,信息抽取,摘要,文案生成,常识问答和数学计算等能力。目前姜子牙通用大模型已完成大规模预训练、多任务有监督微调和人类反馈学习三阶段的训练过程。
The Ziya-LLaMA-13B-v1 is a large-scale pre-trained model based on LLaMA with 13 billion parameters. It has the ability to perform tasks such as translation, programming, text classification, information extraction, summarization, copywriting, common sense Q&A, and mathematical calculation. The Ziya-LLaMA-13B-v1 has undergone three stages of training: large-scale continual pre-training (PT), multi-task supervised fine-tuning (SFT), and human feedback learning (RM, PPO).
软件依赖
1.2.2 模型信息 Model Information
继续预训练 Continual pretraining
原始数据包含英文和中文,其中英文数据来自 openwebtext、Books、Wikipedia 和 Code,中文数据来自清洗后的悟道数据集、自建的中文数据集。在对原始数据进行去重、模型打分、数据分桶、规则过滤、敏感主题过滤和数据评估后,最终得到 125B tokens 的有效数据。
为了解决 LLaMA 原生分词对中文编解码效率低下的问题,我们在 LLaMA 词表的基础上增加了 7k + 个常见中文字,通过和 LLaMA 原生的词表去重,最终得到一个 39410 大小的词表,并通过复用 Transformers 里 LlamaTokenizer 来实现了这一效果。
在增量训练过程中,我们使用了 160 张 40GB 的 A100,采用 2.6M tokens 的训练集样本数量和 FP 16 的混合精度,吞吐量达到 118 TFLOP per GPU per second。因此我们能够在 8 天的时间里在原生的 LLaMA-13B 模型基础上,增量训练 110B tokens 的数据。
训练期间,虽然遇到了机器宕机、底层框架 bug、loss spike 等各种问题,但我们通过快速调整,保证了增量训练的稳定性。我们也放出训练过程的 loss 曲线,让大家了解可能出现的问题。
1.2.3 多任务有监督微调 Supervised finetuning
在多任务有监督微调阶段,采用了课程学习(curiculum learning)和增量训练(continual learning)的策略,用大模型辅助划分已有的数据难度,然后通过 “Easy To Hard” 的方式,分多个阶段进行 SFT 训练。
SFT 训练数据包含多个高质量的数据集,均经过人工筛选和校验:
Self-Instruct 构造的数据(约 2M):BELLE、Alpaca、Alpaca-GPT4 等多个数据集
内部收集 Code 数据(300K):包含 leetcode、多种 Code 任务形式
内部收集推理 / 逻辑相关数据(500K):推理、申论、数学应用题、数值计算等
中英平行语料(2M):中英互译语料、COT 类型翻译语料、古文翻译语料等
多轮对话语料(500K):Self-Instruct 生成、任务型多轮对话、Role-Playing 型多轮对话等
在多任务学习的监督微调(SFT)阶段,我们使用了课程学习和增量训练策略。我们利用大模型辅助对现有数据进行难度划分,然后采用“由易到难”的方法分阶段进行 SFT 训练。
SFT 训练数据由多个人工选择和验证的高质量数据集组成,包括 BELLE、Alpaca 和 Alpaca- gpt4 等数据集的约 200 万样本,包括 LeetCode 和各种代码任务在内的内部采集代码数据的 30 万样本,推理、议论文、数学应用问题和数值计算等内部采集推理/逻辑相关数据的 50 万样本。200 万个汉英平行语料库样本,包括翻译、cot 式翻译、文言文翻译;50 万个多回合对话语料库样本,包括自主生成、任务导向多回合对话、角色扮演多回合对话。
1.2.4 人类反馈学习 Human-Feedback training
为了进一步提升模型的综合表现,使其能够充分理解人类意图、减少 “幻觉” 和不安全的输出,基于指令微调后的模型,进行了人类反馈训练(Human-Feedback Training,HFT)。在训练中,我们采用了以人类反馈强化学习(RM、PPO)为主,结合多种其他手段联合训练的方法,手段包括人类反馈微调(Human-Feedback Fine-tuning,HFFT)、后见链微调(Chain-of-Hindsight Fine-tuning,COHFT)、AI 反馈(AI Feedback)和基于规则的奖励系统(Rule-based Reward System,RBRS)等,用来弥补 PPO 方法的短板,加速训练。
我们在内部自研的框架上实现了 HFT 的训练流程,该框架可以利用最少 8 张 40G 的 A100 显卡完成 Ziya-LLaMA-13B-v1 的全参数训练。在 PPO 训练中,我们没有限制生成样本的长度,以确保长文本任务的奖励准确性。每次训练的总经验池尺寸超过 100k 样本,确保了训练的充分性。
1.2.5 效果评估 Performance
2.Demo 展示
Hugging Face Demo: doing
我们提供了一个简洁的基于 gradio 的交互式 web 界面,启动服务后,可通过浏览器访问,输入问题,模型会返回答案。
启动服务,命令如下:
参数说明:
--model_type {base_model_type}
:预训练模型类型,如 llama、bloom、chatglm 等--base_model {base_model}
:存放 HF 格式的 LLaMA 模型权重和配置文件的目录,也可使用 HF Model Hub 模型调用名称--lora_model {lora_model}
:LoRA 文件所在目录,也可使用 HF Model Hub 模型调用名称。若 lora 权重已经合并到预训练模型,则删除--lora_model 参数--tokenizer_path {tokenizer_path}
:存放对应 tokenizer 的目录。若不提供此参数,则其默认值与--base_model 相同--use_cpu
: 仅使用 CPU 进行推理--gpus {gpu_ids}
: 指定使用的 GPU 设备编号,默认为 0。如使用多张 GPU,以逗号分隔,如 0,1,2
2.1 环境安装
Updating the requirementsFrom time to time, the
requirements.txt
changes. To update, use this command:
2.2 Pipeline 训练
Training Stage:
提供完整四阶段串起来训练的 pipeline:run_training_pipeline.ipynb ,其对应的 colab:
,运行完大概需要 15 分钟,我运行成功后的副本 colab:
2.3 模型支持
The following models are tested:
bloom:
llama:
chatglm:
baichuan:
2.4 模型训练
2.4.1 PT(Continue PreTraining)增量预训练
第一阶段:PT(Continue PreTraining)增量预训练
使用百科类文档类数据集,用来在领域数据集上增量预训练或二次预训练,期望能把领域知识注入给模型,以医疗领域为例,希望增量预训练,能让模型理解感冒的症状、病因、治疗药品、治疗方法、药品疗效等知识,便于后续的 SFT 监督微调能激活这些内在知识。
这里说明一点,像 GPT3、LLaMA 这样的大模型理论上是可以从增量预训练中获益,但增量预训练需要满足两个要求:1)高质量的预训练样本;2)较大的计算资源,显存要求高,即使是用 LoRA 技术,也要满足 block_size=1024 或 2048 长度的文本加载到显存中。
其次,如果你的项目用到的数据是模型预训练中已经使用了的,如维基百科、ArXiv 等 LLaMA 模型预训练用了的,则这些数据是没有必要再喂给 LLaMA 增量预训练,而且预训练样本的质量如果不够高,也可能会损害原模型的生成能力。
tips:PT 阶段是可选项,请慎重处理。
基于 llama-7b 模型,使用医疗百科类数据继续预训练,期望注入医疗知识到预训练模型,得到 llama-7b-pt 模型
Continue pretraining of the base llama-7b model to create llama-7b-pt:
如果你的显存不足,可以改小 batch_size=1, block_size=512(影响训练的上下文最大长度);
如果你的显存更大,可以改大 block_size=2048, 此为 llama 原始预训练长度,不能更大啦;调大 batch_size。
2.4.2 SFT(Supervised Fine-tuning)有监督微调
第二阶段:SFT(Supervised Fine-tuning)有监督微调
基于 llama-7b-pt 模型,使用医疗问答类数据进行有监督微调,得到 llama-7b-sft 模型
Supervised fine-tuning of the base llama-7b-pt model to create llama-7b-sft
2.4.3 RM(Reward Model)奖励模型建模
第三阶段:RM(Reward Model)奖励模型建模
RM(Reward Model)奖励模型,原则上,我们可以直接用人类标注来对模型做 RLHF 微调。
然而,这将需要我们给人类发送一些样本,在每轮优化后计分。这是贵且慢的,因为收敛需要的训练样本量大,而人类阅读和标注的速度有限。一个比直接反馈更好的策略是,在进入 RL 循环之前用人类标注集来训练一个奖励模型 RM。奖励模型的目的是模拟人类对文本的打分。
构建奖励模型的最佳实践是预测结果的排序,即对每个 prompt (输入文本) 对应的两个结果 (yk, yj),模型预测人类标注的比分哪个更高。RM 模型是通过人工标注 SFT 模型的打分结果来训练的,目的是取代人工打分,本质是个回归模型,用来对齐人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless"。
基于 llama-7b-sft 模型,使用医疗问答偏好数据训练奖励偏好模型,训练得到 llama-7b-reward 模型
Reward modeling using dialog pairs from the reward dataset using the llama-7b-sft to create llama-7b-reward:
2.4.4 基于人类反馈的强化学习(RLHF)
第四阶段:RL(Reinforcement Learning)基于人类反馈的强化学习(RLHF)
RL(Reinforcement Learning)模型的目的是最大化奖励模型的输出,基于上面步骤,我们有了微调的语言模型(llama-7b-sft)和奖励模型(llama-7b-reward),可以开始执行 RL 循环了。
这个过程大致分为三步:
输入 prompt,模型生成答复
用奖励模型来对答复评分
基于评分,进行一轮策略优化的强化学习(PPO)
基于 llama-7b-reward 模型 RL 微调训练 llama-7b-sft 模型,得到 llama-7b-rl 模型
Reinforcement Learning fine-tuning of llama-7b-sft with the llama-7b-reward reward model to create llama-7b-rl
2.5 推理预测
训练完成后,现在我们加载训练好的模型,验证模型生成文本的效果。
参数说明:
--model_type {base_model_type}
:预训练模型类型,如 llama、bloom、chatglm 等--base_model {base_model}
:存放 HF 格式的 LLaMA 模型权重和配置文件的目录--lora_model {lora_model}
:LoRA 解压后文件所在目录,也可使用 HF Model Hub 模型调用名称。如果已经合并了 LoRA 权重到预训练模型,则可以不提供此参数--tokenizer_path {tokenizer_path}
:存放对应 tokenizer 的目录。若不提供此参数,则其默认值与--base_model 相同--with_prompt
:是否将输入与 prompt 模版进行合并。如果加载 Alpaca 模型,请务必启用此选项!--interactive
:以交互方式启动,以便进行多次单轮问答--data_file {file_name}
:非交互方式启动下,按行读取 file_name 中的的内容进行预测--predictions_file {file_name}
:非交互式方式下,将预测的结果以 json 格式写入 file_name--use_cpu
: 仅使用 CPU 进行推理--gpus {gpu_ids}
: 指定使用的 GPU 设备编号,默认为 0。如使用多张 GPU,以逗号分隔,如 0,1,2
2.5.1 推理样例
shibing624/ziya-llama-13b-medical-lora inference examples:
3.数据集
3.1 医疗数据集
240 万条中文医疗数据集(包括预训练、指令微调和奖励数据集):shibing624/medical
22 万条中文医疗对话数据集(华佗项目):FreedomIntelligence/HuatuoGPT-sft-data-v1
3.2 通用数据集
3.2.1 SFT datasets
50 万条中文 ChatGPT 指令 Belle 数据集:BelleGroup/train_0.5M_CN
100 万条中文 ChatGPT 指令 Belle 数据集:BelleGroup/train_1M_CN
5 万条英文 ChatGPT 指令 Alpaca 数据集:50k English Stanford Alpaca dataset
2 万条中文 ChatGPT 指令 Alpaca 数据集:shibing624/alpaca-zh
69 万条中文指令 Guanaco 数据集(Belle50 万条+Guanaco19 万条):Chinese-Vicuna/guanaco_belle_merge_v1.0
5 万条英文 ChatGPT 多轮对话数据集:RyokoAI/ShareGPT52K
80 万条中文 ChatGPT 多轮对话数据集:BelleGroup/multiturn_chat_0.8M
116 万条中文 ChatGPT 多轮对话数据集:fnlp/moss-002-sft-data
3.2.2 Reward Model datasets
原版的 oasst1 数据集:OpenAssistant/oasst1
2 万条多语言 oasst1 的 reward 数据集:tasksource/oasst1_pairwise_rlhf_reward
11 万条英文 hh-rlhf 的 reward 数据集:Dahoas/full-hh-rlhf
9 万条英文 reward 数据集(来自 Anthropic's Helpful Harmless dataset):Dahoas/static-hh
7 万条英文 reward 数据集(来源同上):Dahoas/rm-static
7 万条繁体中文的 reward 数据集(翻译自 rm-static)liswei/rm-static-m2m100-zh
7 万条英文 Reward 数据集:yitingxie/rlhf-reward-datasets
3 千条中文知乎问答偏好数据集:liyucheng/zhihu_rlhf_3k
更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。
版权声明: 本文为 InfoQ 作者【汀丶人工智能】的原创文章。
原文链接:【http://xie.infoq.cn/article/45eb0bb82a5eca1d563f74d2b】。
本文遵守【CC-BY 4.0】协议,转载请保留原文出处及本版权声明。
评论