写点什么

chatglm2-6b 在 P40 上做 LORA 微调 | 京东云技术团队

  • 2023-09-06
    北京
  • 本文字数:2787 字

    阅读完需:约 9 分钟

chatglm2-6b在P40上做LORA微调 | 京东云技术团队

背景:

目前,大模型的技术应用已经遍地开花。最快的应用方式无非是利用自有垂直领域的数据进行模型微调。chatglm2-6b 在国内开源的大模型上,效果比较突出。本文章分享的内容是用 chatglm2-6b 模型在集团 EA 的 P40 机器上进行垂直领域的 LORA 微调。

一、chatglm2-6b 介绍

github: https://github.com/THUDM/ChatGLM2-6B


chatglm2-6b 相比于 chatglm 有几方面的提升:


1. 性能提升: 相比初代模型,升级了 ChatGLM2-6B 的基座模型,同时在各项数据集评测上取得了不错的成绩;


2. 更长的上下文: 我们将基座模型的上下文长度(Context Length)由 ChatGLM-6B 的 2K 扩展到了 32K,并在对话阶段使用 8K 的上下文长度训练;


3. 更高效的推理: 基于 Multi-Query Attention 技术,ChatGLM2-6B 有更高效的推理速度和更低的显存占用:在官方的模型实现下,推理速度相比初代提升了 42%;


4. 更开放的协议:ChatGLM2-6B 权重对学术研究完全开放,在填写问卷进行登记后亦允许免费商业使用。

二、微调环境介绍

2.1 性能要求

推理这块,chatglm2-6b 在精度是 fp16 上只需要 14G 的显存,所以 P40 是可以 cover 的。



EA 上 P40 显卡的配置如下:


2.2 镜像环境

做微调之前,需要编译环境进行配置,我这块用的是 docker 镜像的方式来加载镜像环境,具体配置如下:


FROM base-clone-mamba-py37-cuda11.0-gpu
# mpichRUN yum install mpich
# create my own environmentRUN conda create -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ --override --yes --name py39 python=3.9# display my own environment in LauncherRUN source activate py39 \ && conda install --yes --quiet ipykernel \ && python -m ipykernel install --name py39 --display-name "py39"
# install your own requirement packageRUN source activate py39 \ && conda install -y -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ \ pytorch torchvision torchaudio faiss-gpu \ && pip install --no-cache-dir --ignore-installed -i https://pypi.tuna.tsinghua.edu.cn/simple \ protobuf \ streamlit \ transformers==4.29.1 \ cpm_kernels \ mdtex2html \ gradio==3.28.3 \ sentencepiece \ accelerate \ langchain \ pymupdf \ unstructured[local-inference] \ layoutparser[layoutmodels,tesseract] \ nltk~=3.8.1 \ sentence-transformers \ beautifulsoup4 \ icetk \ fastapi~=0.95.0 \ uvicorn~=0.21.1 \ pypinyin~=0.48.0 \ click~=8.1.3 \ tabulate \ feedparser \ azure-core \ openai \ pydantic~=1.10.7 \ starlette~=0.26.1 \ numpy~=1.23.5 \ tqdm~=4.65.0 \ requests~=2.28.2 \ rouge_chinese \ jieba \ datasets \ deepspeed \ pdf2image \ urllib3==1.26.15 \ tenacity~=8.2.2 \ autopep8 \ paddleocr \ mpi4py \ tiktoken
复制代码


如果需要使用 deepspeed 方式来训练, EA 上缺少 mpich 信息传递工具包,需要自己手动安装。

2.3 模型下载

huggingface 地址: https://huggingface.co/THUDM/chatglm2-6b/tree/main

三、LORA 微调

3.1 LORA 介绍

paper: https://arxiv.org/pdf/2106.09685.pdf


LORA(Low-Rank Adaptation of Large Language Models)微调方法: 冻结预训练好的模型权重参数,在冻结原模型参数的情况下,通过往模型中加入额外的网络层,并只训练这些新增的网络层参数。



LoRA 的思想:


  • 在原始 PLM (Pre-trained Language Model) 旁边增加一个旁路,做一个降维再升维的操作。

  • 训练的时候固定 PLM 的参数,只训练降维矩阵 A 与升维矩 B。而模型的输入输出维度不变,输出时将 BA 与 PLM 的参数叠加。

  • 用随机高斯分布初始化 A,用 0 矩阵初始化 B,保证训练的开始此旁路矩阵依然是 0 矩阵。

3.2 微调

huggingface 提供的 peft 工具可以方便微调 PLM 模型,这里也是采用的 peft 工具来创建 LORA。


peft 的 github: https://gitcode.net/mirrors/huggingface/peft?utm_source=csdn_github_accelerator


加载模型和 lora 微调:


    # load model    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)    model = AutoModel.from_pretrained(args.model_dir, trust_remote_code=True)        print("tokenizer:", tokenizer)        # get LoRA model    config = LoraConfig(        r=args.lora_r,        lora_alpha=32,        lora_dropout=0.1,        bias="none",)        # 加载lora模型    model = get_peft_model(model, config)    # 半精度方式    model = model.half().to(device)
复制代码


这里需要注意的是,用 huggingface 加载本地模型,需要创建 work 文件,EA 上没有权限在没有在.cache 创建,这里需要自己先制定 work 路径。


import osos.environ['TRANSFORMERS_CACHE'] = os.path.dirname(os.path.abspath(__file__))+"/work/"os.environ['HF_MODULES_CACHE'] = os.path.dirname(os.path.abspath(__file__))+"/work/"
复制代码


如果需要用 deepspeed 方式训练,选择你需要的 zero-stage 方式:


    conf = {"train_micro_batch_size_per_gpu": args.train_batch_size,            "gradient_accumulation_steps": args.gradient_accumulation_steps,            "optimizer": {                "type": "Adam",                "params": {                    "lr": 1e-5,                    "betas": [                        0.9,                        0.95                    ],                    "eps": 1e-8,                    "weight_decay": 5e-4                }            },            "fp16": {                "enabled": True            },            "zero_optimization": {                "stage": 1,                "offload_optimizer": {                    "device": "cpu",                    "pin_memory": True                },                "allgather_partitions": True,                "allgather_bucket_size": 2e8,                "overlap_comm": True,                "reduce_scatter": True,                "reduce_bucket_size": 2e8,                "contiguous_gradients": True            },            "steps_per_print": args.log_steps            }
复制代码


其他都是数据处理处理方面的工作,需要关注的就是怎么去构建 prompt,个人认为在领域内做微调构建 prompt 非常重要,最终对模型的影响也比较大。

四、微调结果

目前模型还在 finetune 中,batch=1,epoch=3,已经迭代一轮。



作者:京东零售 郑少强

来源:京东云开发者社区 转载请注明来源

发布于: 6 小时前阅读数: 2
用户头像

拥抱技术,与开发者携手创造未来! 2018-11-20 加入

我们将持续为人工智能、大数据、云计算、物联网等相关领域的开发者,提供技术干货、行业技术内容、技术落地实践等文章内容。京东云开发者社区官方网站【https://developer.jdcloud.com/】,欢迎大家来玩

评论

发布
暂无评论
chatglm2-6b在P40上做LORA微调 | 京东云技术团队_人工智能_京东科技开发者_InfoQ写作社区