写点什么

文盘 rust-- 使用 Rust 构建 RAG

  • 2024-10-08
    北京
  • 本文字数:9856 字

    阅读完需:约 32 分钟

作者:京东科技 贾世闻


RAG(Retrieval-Augmented Generation)技术在 AI 生态系统中扮演着至关重要的角色,特别是在提升大型语言模型(LLMs)的准确性和应用范围方面。RAG 通过结合检索技术与 LLM 提示,从各种数据源检索相关信息,并将其与用户的问题结合,生成准确且丰富的回答。这一机制特别适用于需要应对信息不断更新的场景,因为大语言模型所依赖的参数知识本质上是静态的。


RAG 技术的优势在于它能够利用外部知识库,引用大量的信息,以提供更深入、准确且有价值的答案,提高了生成文本的可靠性。此外,RAG 模型具备检索库的更新机制,可以实现知识的即时更新,无需重新训练模型,这在及时性要求高的应用中占优势。


目前构建一个 RAG 并不是一个非常的事情。使用 Langchain 等成熟技术架构百十行代码就能构建一个 Demo。那能不能利用目前的 Rust 生态构建一个简易的 RAG。说干就干,本期和大家聊聊如果使用 rust 语言构建 rag。

构建知识库

知识库构建主要是模型+向量库,为了保证所有系统中所有组件都使用 rust 构建,在限量数据库的选型上我们使用qdrant,纯 rust 构建的向量数据库。


知识库的构建最重要的步骤是 embedding 的过程。


过程如下:


  • 模型加载

  • 获取文本 token

  • 通过模型获取文本的 Embedding

  • 下面详细介绍每个过程细节及代码实现。

模型加载

以下代码用于加载模型和 tokenizer


async fn build_model_and_tokenizer(model_config: &ConfigModel) -> Result<(BertModel, Tokenizer)> {    let device = Device::new_cuda(0)?;    let repo = Repo::with_revision(        model_config.model_id.clone(),        RepoType::Model,        model_config.revision.clone(),    );    let (config_filename, tokenizer_filename, weights_filename) = {        let api = ApiBuilder::new()                .build()?;        let api = api.repo(repo);        let config = api.get("config.json").await?;        let tokenizer = api.get("tokenizer.json").await?;        let weights = if model_config.use_pth {            api.get("pytorch_model.bin").await?        } else {            api.get("model.safetensors").await?        };        (config, tokenizer, weights)A    };    let config = std::fs::read_to_string(config_filename)?;    let mut config: Config = serde_json::from_str(&config)?;    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if model_config.use_pth { VarBuilder::from_pth(&weights_filename, DTYPE, &device)? } else { unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? } }; if model_config.approximate_gelu { config.hidden_act = HiddenAct::GeluApproximate; } let model = BertModel::load(vb, &config)?; Ok((model, tokenizer))}
复制代码


模型和 tokenizer 是系统中频繁调用的部分,所以为了避免重复加载,通过 OnceCell 构建静态全局变量


pub static GLOBAL_EMBEDDING_MODEL: OnceCell> = OnceCell::const_new();
pub async fn init_model_and_tokenizer() -> Arc<(BertModel, Tokenizer)> { let config = get_config().unwrap(); let (m, t) = build_model_and_tokenizer(&config.model).await.unwrap(); Arc::new((m, t))}
复制代码


在系统启动时加载模型


GLOBAL_RUNTIME.block_on(async {    log::info!("global runtime start!");    // 加载model    GLOBAL_EMBEDDING_MODEL        .get_or_init(init_model_and_tokenizer)        .await;});
复制代码


Embedding 过程主要由一下函数实现。


pub async fn embedding_setence(content: &str) -> Result>> {    let m_t = GLOBAL_EMBEDDING_MODEL.get().unwrap();    let tokens = m_t        .1        .encode(content, true)        .map_err(E::msg)?        .get_ids()        .to_vec();    let token_ids = Tensor::new(&tokens[..], &m_t.0.device)?.unsqueeze(0)?;    let token_type_ids = token_ids.zeros_like()?;    let sequence_output = m_t.0.forward(&token_ids, &token_type_ids)?;    let (_n_sentence, n_tokens, _hidden_size) = sequence_output.dims3()?;    let embeddings = (sequence_output.sum(1)? / (n_tokens as f64))?;    let embeddings = normalize_l2(&embeddings)?;    let encodings = embeddings.to_vec2::()?;    Ok(encodings)}
复制代码


函数通过 tokenizer encode 输入的文本,再使用模型 embed token 获取一个三维的 Tensor,最后归一化张量。

数据入库

知识库构建是将待检索文本向量化后存储到向量数据库的过程。


本次使用京东云文档作为原始文本,加工为以下格式。数据加工过程这里就不累述了。


{    "content": "# 服务计费\n\n主机迁移服务自身为免费服务,但是迁移目标为云主机镜像时,迁移过程依赖系统自动创建的 中转资源的配合,这些资源中涉及部分付费资源,会产生相应费用。\n\n迁移过程涉及的中转资付费资源配置及计费说明如下(单个迁移任务):\n\n|  | 云主机 | 云硬盘 | 弹性公网IP |\n| --- | --- | --- | ------ |\n| 计费类型 | 按配置 | 按配置 | 按用量 |\n| 规格配置 | 2C4G (c.n2.large或c.n3.large或c.n1.large) | 系统盘:40G 通用型SSD 数据盘:通用型SSD,数量及容量取决于源服务器系统盘及数据盘情况 | 30Mbps |\n| 费用预估 | 云主机规格每小时价格\\*迁移时长 | 云硬盘规格每小时价格\\*迁移时长 | 弹性公网IP每小时保有费\\*迁移时长 仅使用弹性公网IP入方向流量,只涉及IP保有用,不涉及流量费用 |\n\n> 提示:\n>\n> * 迁移时长取决于源服务器迁数据量以及源服务器公网出方向带宽,公网连接顺畅且源服务器公网出方向带宽不低于22.5Mbps的情况下(主机迁移为单线程传输,京东云云主机在单流传输下实际带宽为带宽上限的75%左右),实际数据容量为5GB的磁盘迁移时长在30分钟左右。\n> * 中转实例实例绑定的安全组出方向默认拒绝所有流量,因此默认情况下降不会产生任何公网出方向收费流量,但此配置也影响了云主机部分监控指标的上报,如需要监控中转实例的全部监控数据,可自行调整安全组规则方向出方向443端口。",    "title": "服务计费说明",    "product": "云主机 CVM",    "url": "https://docs.jdcloud.com/cn/virtual-machines/server-migration-service/billing"}
复制代码


入库完整代码如下:


use anyhow::Error as E;use anyhow::Result;use candle_core::Device;use candle_core::Tensor;use candle_nn::VarBuilder;use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};use hf_hub::{api::tokio::Api, Repo, RepoType};use qdrant_client::qdrant::CollectionExistsRequest;use qdrant_client::qdrant::CreateCollectionBuilder;use qdrant_client::qdrant::DeleteCollection;use qdrant_client::qdrant::Distance;use qdrant_client::qdrant::UpsertPointsBuilder;use qdrant_client::qdrant::VectorParamsBuilder;use qdrant_client::Payload;use qdrant_client::{    qdrant::{        CollectionOperationResponse, CreateCollection, PointStruct, PointsOperationResponse,        UpsertPoints,    },    Qdrant,};use serde::{Deserialize, Serialize};use serde_json::from_str;use std::fs;use std::sync::Arc;use tokenizers::Tokenizer;use tokio::sync::OnceCell;use uuid::Uuid;use walkdir::WalkDir;
#[derive(Debug, Serialize, Deserialize, Clone)]pub struct Doc { pub content: String, pub title: String, pub product: String, pub url: String,}
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]pub struct ModelConfig { #[serde(default = "ModelConfig::model_id_default")] pub model_id: String, #[serde(default = "ModelConfig::revision_default")] pub revision: String, #[serde(default = "ModelConfig::use_pth_default")] pub use_pth: bool, #[serde(default = "ModelConfig::approximate_gelu_default")] pub approximate_gelu: bool,}
impl Default for ModelConfig { fn default() -> Self { Self { model_id: Self::model_id_default(), revision: Self::revision_default(), use_pth: Self::use_pth_default(), approximate_gelu: Self::approximate_gelu_default(), } }}
impl ModelConfig { fn model_id_default() -> String { "moka-ai/m3e-large".to_string() } fn revision_default() -> String { "main".to_string() } fn use_pth_default() -> bool { true } fn approximate_gelu_default() -> bool { false }}
pub static GLOBAL_MODEL: OnceCell> = OnceCell::const_new();pub static GLOBAL_TOKEN: OnceCell> = OnceCell::const_new();
pub async fn init_model() -> Arc { let config = ModelConfig::default(); let (m, _) = build_model_and_tokenizer(&config).await.unwrap(); Arc::new(m)}
pub async fn init_tokenizer() -> Arc { let config = ModelConfig::default(); let (_, t) = build_model_and_tokenizer(&config).await.unwrap(); Arc::new(t)}
async fn build_model_and_tokenizer(model_config: &ModelConfig) -> Result<(BertModel, Tokenizer)> { let device = Device::new_cuda(0)?; let repo = Repo::with_revision( model_config.model_id.clone(), RepoType::Model, model_config.revision.clone(), ); let (config_filename, tokenizer_filename, weights_filename) = { let api = Api::new()?; let api = api.repo(repo); let config = api.get("config.json").await?; let tokenizer = api.get("tokenizer.json").await?; let weights = if model_config.use_pth { api.get("pytorch_model.bin").await? } else { api.get("model.safetensors").await? }; (config, tokenizer, weights) }; let config = std::fs::read_to_string(config_filename)?; let mut config: Config = serde_json::from_str(&config)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if model_config.use_pth { VarBuilder::from_pth(&weights_filename, DTYPE, &device)? } else { unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? } }; if model_config.approximate_gelu { config.hidden_act = HiddenAct::GeluApproximate; } let model = BertModel::load(vb, &config)?; Ok((model, tokenizer))}
pub async fn embedding_setence(content: &str) -> Result>> { let m = GLOBAL_MODEL.get().unwrap(); let t = GLOBAL_TOKEN.get().unwrap(); let tokens = t.encode(content, true).map_err(E::msg)?.get_ids().to_vec();
let token_ids = Tensor::new(&tokens[..], &m.device)?.unsqueeze(0)?; let token_type_ids = token_ids.zeros_like()?;
let sequence_output = m.forward(&token_ids, &token_type_ids)?; let (_n_sentence, n_tokens, _hidden_size) = sequence_output.dims3()?; let embeddings = (sequence_output.sum(1).unwrap() / (n_tokens as f64))?; let embeddings = normalize_l2(&embeddings).unwrap(); let encodings = embeddings.to_vec2::()?; Ok(encodings)}
pub fn normalize_l2(v: &Tensor) -> Result { Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)}
pub struct QdrantClient { client: Qdrant,}
impl QdrantClient { pub async fn create_collection( &self, request: impl Into, ) -> Result { let resp = self.client.create_collection(request).await?; Ok(resp) }
pub async fn delete_collection( &self, request: impl Into, ) -> Result { let resp = self.client.delete_collection(request).await?; Ok(resp) }
pub async fn collection_exists( &self, request: impl Into, ) -> Result { let resp = self.client.collection_exists(request).await?; Ok(resp) }
pub async fn load_dir(&self, path: &str, collection_name: &str) { let mut points = vec![]; for entry in WalkDir::new(path) .into_iter() .filter_map(Result::ok) .filter(|e| !e.file_type().is_dir() && e.file_name().to_str().is_some()) { if let Some(p) = entry.path().to_str() { let id = Uuid::new_v4(); let content = match fs::read_to_string(p) { Ok(c) => c, Err(_) => continue, };
let doc = match from_str::(content.as_str()) { Ok(d) => d, Err(_) => continue, }; let mut payload = Payload::new(); payload.insert("content", doc.content); payload.insert("title", doc.title); payload.insert("product", doc.product); payload.insert("url", doc.url); let vector_contens = embedding_setence(content.as_str()).await.unwrap(); let ps = PointStruct::new(id.to_string(), vector_contens[0].clone(), payload); points.push(ps);
if points.len().eq(&100) { let p = points.clone(); self.client .upsert_points(UpsertPointsBuilder::new(collection_name, p).wait(true)) .await .unwrap(); points.clear(); println!("batch finish"); } } }
if points.len().gt(&0) { self.client .upsert_points(UpsertPointsBuilder::new(collection_name, points).wait(true)) .await .unwrap(); } }}
#[tokio::main]async fn main() { // 加载模型 GLOBAL_MODEL.get_or_init(init_model).await; GLOBAL_TOKEN.get_or_init(init_tokenizer).await;
let collection_name = "default_collection";
// The Rust client uses Qdrant's GRPC interface let qdrant = Qdrant::from_url("http://localhost:6334").build().unwrap(); let qdrant_client = QdrantClient { client: qdrant };
if !qdrant_client .collection_exists(collection_name) .await .unwrap() { qdrant_client .create_collection( CreateCollectionBuilder::new(collection_name) .vectors_config(VectorParamsBuilder::new(1024, Distance::Dot)), ) .await .unwrap(); }
qdrant_client .load_dir("/root/jd_docs", collection_name) .await;
println!("{:?}", qdrant_client.client.health_check().await);}
复制代码


以上代码要完成的任务如下:

推理服务

推理服务使用 rust 构建的 mistral.rs


由于国内访问 hf 并不方便所以先通过 https://hf-mirror.com/ 现将模型下载到本地。本次使用 qwen 模型


HF_ENDPOINT="https://hf-mirror.com"  huggingface-cli download --repo-type model --resume-download Qwen/Qwen2-7B --local-dir /root/Qwen2-7B
复制代码


启动 mistralrs-server


git clone https://github.com/EricLBuehler/mistral.rscd mistral.rscargo run  --bin mistralrs-server  --features cuda -- --port 3333 plain -m /root/Qwen2-7B  -a qwen2
复制代码

推理服务调用

mistral.rs 支持 Openai 的 api 接口,使用 openai-api-rs调用即可。推理时间比较长 timeout 要设置长一些,若 timeout 时间太短有可能不等返回结果就已经强制超时。


pub static GLOBAL_OPENAI_CLIENT: Lazy> = Lazy::new(|| {    let mut client =        OpenAIClient::new_with_endpoint("http://10.0.0.7:3333/v1".to_string(), "EMPTY".to_string());    client.timeout = Some(30);    Arc::new(client)});
pub async fn inference(content: &str, max_len: i64) -> Result> { let req = ChatCompletionRequest::new( "".to_string(), vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: chat_completion::Content::Text(content.to_string()), name: None, tool_calls: None, tool_call_id: None, }], ) .max_tokens(max_len);
let cr = GLOBAL_OPENAI_CLIENT.chat_completion(req).await?; Ok(cr.choices[0].message.content.clone())}
复制代码


将 Retriever 和推理服务集成


pub async fn answer(question: &str, max_len: i64) -> Result> {    let retriver = retriever(question, 1).await?;    let mut context = "".to_string();
for sp in retriver.result { let payload = sp.payload; let product = payload.get("product").unwrap().to_string(); let title = payload.get("title").unwrap().to_string(); let content = payload.get("content").unwrap().to_string(); context.push_str(&product); context.push_str(&title); context.push_str(&content); }
let prompt = format!( "你是一个云技术专家, 使用以下检索到的Context回答问题。用中文回答问题。 Question: {} Context: {} ", question, context );
log::info!("{}", prompt);
let req = ChatCompletionRequest::new( "".to_string(), vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: chat_completion::Content::Text(prompt), name: None, tool_calls: None, tool_call_id: None, }], ) .max_tokens(max_len);
let cr = GLOBAL_OPENAI_CLIENT.chat_completion(req).await?; Ok(cr.choices[0].message.content.clone())}
复制代码

后记

完整工程地址[embedding_server]https://github.com/jiashiwen/embedding_server


后续工程问题,多卡推理,多机推理,推理加速


资源对比


  • GPU 型号


    |=========================================+========================+======================|    |   0  NVIDIA A30                     Off |   00000000:00:07.0 Off |                    0 |    | N/A   30C    P0             29W /  165W |       0MiB /  24576MiB |      0%      Default |    |                                         |                        |             Disabled |    +-----------------------------------------+------------------------+----------------------+
复制代码


  • Embedding 资源

  • m3e-large

  • vllm


            +-----------------------------------------------------------------------------------------+            | Processes:                                                                              |            |  GPU   GI   CI        PID   Type   Process name                              GPU Memory |            |        ID   ID                                                               Usage      |            |=========================================================================================|            |    0   N/A  N/A    822789      C   ...iprojects/rag_demo/.venv/bin/python       1550MiB |            +-----------------------------------------------------------------------------------------+
复制代码


    -   candle
复制代码


            +-----------------------------------------------------------------------------------------+            | Processes:                                                                              |            |  GPU   GI   CI        PID   Type   Process name                              GPU Memory |            |        ID   ID                                                               Usage      |            |=========================================================================================|            |    0   N/A  N/A    823261      C   target/debug/embedding_server                1484MiB |            +-----------------------------------------------------------------------------------------+
复制代码


  • 推理资源

  • Qwen1.5-1.8B-Chat

  • vllm


            |=========================================================================================|            |    0   N/A  N/A    822437      C   /usr/bin/python3                            20440MiB |            +-----------------------------------------------------------------------------------------+
复制代码


    -   mistral.rs
复制代码


            |=========================================================================================|            |    0   N/A  N/A    822174      C   target/debug/mistralrs-server               22134MiB |            +-----------------------------------------------------------------------------------------+
复制代码


-   Qwen2-7B
- vllm 现存溢出
复制代码


            [rank0]: OutOfMemoryError: CUDA out of memory. Tried to allocate 9.25 GiB. GPU
复制代码


    -   mistral.rs
复制代码


            |=========================================================================================|            |    0   N/A  N/A    656923      C   target/debug/mistralrs-server               22006MiB |            +-----------------------------------------------------------------------------------------+
复制代码


从实际情况来看,Embedding 模型再资源占用情况 rust candle 框架使用显存略小些;推理模型 Qwen1.5-1.8B-Chat,vllm 资源占用略小。Qwen2-7B vllm 直接显存溢出。

大部分框架中使用 hf-hub 采用同步调用,不支持境内的 mirror。动手改造


src/api/tokio.rs



impl ApiBuilder { /// Set endpoint example 'https://hf-mirror.com' pub fn with_endpoint(mut self, endpoint: &str) -> Self { self.endpoint = endpoint.to_string(); self }}
复制代码


发布于: 刚刚阅读数: 5
用户头像

拥抱技术,与开发者携手创造未来! 2018-11-20 加入

我们将持续为人工智能、大数据、云计算、物联网等相关领域的开发者,提供技术干货、行业技术内容、技术落地实践等文章内容。京东云开发者社区官方网站【https://developer.jdcloud.com/】,欢迎大家来玩

评论

发布
暂无评论
文盘rust--使用 Rust 构建RAG_京东科技开发者_InfoQ写作社区