前置知识
BM25 简介
BM25 算法(Best Matching 25)是一种广泛用于信息检索领域的排名函数,用于在给定查询(Query)时对一组文档(Document)进行评分和排序。BM25 在计算 Query 和 Document 之间的相似度时,本质上是依次计算 Query 中每个单词和 Document 的相关性,然后对每个单词的相关性进行加权求和。BM25 算法一般可以表示为如下形式:
上式中, q 和 d 分别表示用来计算相似度的 Query 和 Document, qi 表示 q 的第 i 个单词, R(qi, d) 表示单词 qi 和文档 d 的相关性, Wi 表示单词 qi 的权重,计算得到的 score(q, d) 表示 q 和 d 的相关性得分,得分越高表示 q 和 d 越相似。 Wi 和 R(qi, d) 一般可以表示为如下形式:
其中, N 表示总文档数, N(qi) 表示包含单词 qi 的文档数, tf(qi, d) 表示 qi 在文档 d 中的词频, Ld 表示文档 d 的长度, Lavg 表示平均文档长度, k1 和 b 是分别用来控制 tf(qi, d) 和 Ld 对得分影响的超参数。
稀疏向量生成
在检索场景中,为了让 BM25 算法的 Score 方便进行计算,通常分别对 Document 和 Query 进行编码,然后通过 点积 的方式计算出两者的相似度。得益于 BM25 原理的特性,其原生支持将 Score 拆分为两部分 Sparse Vector,DashText 提供了encode_document
以及encode_query
两个接口来分别实现这两部分向量的生成,其生成链路如下图所示:
最终生成的稀疏向量可表示为:
Score/距离计算
生成 d 和 q 的稀疏向量后,就可以通过简单的点积进行距离计算,即将相同单词上的值对应相乘再求和,通过稀疏向量计算距离的方式如下所示:
上述计算方式本质上是通过点积来计算的, score 越大表示越相似,如果需要结合 Dense Vector 一起进行距离度量时,需要对齐距离度量方式。也就是说,在结合 Dense Vector+Sparse Vector 的场景中,距离计算只支持点积度量方式。
如何自训练模型
考虑到内置的 BM25 Model 是基于通用语料(中文Wiki语料)训练得到,在特定领域下通常不能表现出最佳的效果。因此,在一些特定场景下,通常建议训练自定义 BM25 模型。使用 DashText 来训练自定义模型时一般需要遵循以下步骤:
Step1:确认使用场景
当准备使用 SparseVector 来进行信息检索时,应提前考虑当前场景下的 Query 以及 Document 来源,通常需要提前准备好一定数量 Document 来入库,这些 Document 通常需要和特定的业务场景直接相关。
Step2:准备语料
根据 BM25 原理,语料直接决定了 BM25 模型的参数。通常应按照以下几个原则来准备语料:
一般情况下,如无特殊要求或限制,可以直接将 Step1 准备的一系列 Document 组织为语料即可。
Step3:准备 Tokenizer
Tokenizer 决定了分词的结果,分词的结果则直接影响 Sparse Vector 的生成,在特定领域下使用自定义 Tokenizer 会达到更好的效果。DashText 提供了两种扩展 Tokenizer 的方式:
Python 示例:
from dashtext import TextTokenizer, SparseVectorEncoder
my_tokenizer = TextTokenizer.from_pretrained(model_name='Jieba', dict='dict.txt')
my_encoder = SparseVectorEncoder(tokenize_function=my_tokenizer.tokenize)
复制代码
Python 示例:
from dashtext import SparseVectorEncoder
from transformers import BertTokenizer
my_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
my_encoder = SparseVectorEncoder(tokenize_function=my_tokenizer.tokenize)
复制代码
Java 示例:
import com.aliyun.dashtext.common.DashTextException;
import com.aliyun.dashtext.common.ErrorCode;
import com.aliyun.dashtext.encoder.SparseVectorEncoder;
import com.aliyun.dashtext.tokenizer.BaseTokenizer;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
public class Main {
public static class MyTokenizer implements BaseTokenizer {
@Override
public List<String> tokenize(String s) throws DashTextException {
if (s == null) {
throw new DashTextException(ErrorCode.INVALID_ARGUMENT);
}
// 使用正则表达式将文本按空白符和标点符号分割,并转换为小写
return Arrays.stream(s.split("\\s+|(?<!\\d)[.,](?!\\d)"))
.map(String::toLowerCase)
.filter(token -> !token.isEmpty()) // 过滤掉空字符串
.collect(Collectors.toList());
}
}
public static void main(String[] args) {
SparseVectorEncoder encoder = new SparseVectorEncoder(new MyTokenizer());
}
}
复制代码
Step4:训练模型
实际上,这里的"训练"本质上是一个"统计"参数的过程。由于训练自定义模型的过程中包含着大量 Tokenizing/Hashing 过程,所以可能会耗费一定的时间。DashText 提供了SparseVectorEncoder.train
接口可以用来训练模型。
Step5:调参优化(可选)
模型训练完成后,可以准备部分验证数据集以及通过微调 k1 和 b 来达到最佳的召回效果。调节 k1 和 b 一般需要遵循以下原则:
一般情况下,如无特殊要求或限制,不需要调整 k1 和 b 。
Step6:Finetune 模型(可选)
实际场景下,可能会存在需要补充训练语料来增量式地更新 BM25 模型参数的情况。DashText 的SparseVectorEncoder.train
接口原生支持模型的增量更新。需要注意的是,模型更改之后,使用旧模型进行编码并已入库的向量就失去了时效性,一般需要重新入库。
示例代码
以下是一个简单完整的自训练模型示例。
Python 示例:
from dashtext import SparseVectorEncoder
from pydantic import BaseModel
from typing import Dict, List
class Result(BaseModel):
doc: str
score: float
def calculate_score(query_vector: Dict[int, float], document_vector: Dict[int, float]) -> float:
score = 0.0
for key, value in query_vector.items():
if key in document_vector:
score += value * document_vector[key]
return score
# 创建空SparseVectorEncoder(可以设置自定义Tokenizer)
encoder = SparseVectorEncoder()
# step1: 准备语料以及Documents
corpus_document: List[str] = [
"The quick brown fox rapidly and agilely leaps over the lazy dog that lies idly by the roadside.",
"Never jump over the lazy dog quickly",
"A fox is quick and jumps over dogs",
"The quick brown fox",
"Dogs are domestic animals",
"Some dog breeds are quick and jump high",
"Foxes are wild animals and often have a brown coat",
]
# step2: 训练BM25 Model
encoder.train(corpus_document)
# step3: 调参优化BM25 Model
query: str = "quick brown fox"
print(f"query: {query}")
k1s = [1.0, 1.5]
bs = [0.5, 0.75]
for k1, b in zip(k1s, bs):
print(f"current k1: {k1}, b: {b}")
encoder.b = b
encoder.k1 = k1
query_vector = encoder.encode_queries(query)
results: List[Result] = []
for idx, doc in enumerate(corpus_document):
doc_vector = encoder.encode_documents(doc)
score = calculate_score(query_vector, doc_vector)
results.append(Result(doc=doc, score=score))
results.sort(key=lambda r: r.score, reverse=True)
for result in results:
print(result)
# step4: 选择最优参数并保存模型
encoder.b = 0.75
encoder.k1 = 1.5
encoder.dump("./model.json")
# step5: 后续使用时可以加载模型
new_encoder = SparseVectorEncoder()
bm25_model_path = "./model.json"
new_encoder.load(bm25_model_path)
# step6: 对模型进行finetune并保存
extra_corpus: List[str] = [
"The fast fox jumps over the lazy, chubby dog",
"A swift fox hops over a napping old dog",
"The quick fox leaps over the sleepy, plump dog",
"The agile fox jumps over the dozing, heavy-set dog",
"A speedy fox vaults over a lazy, old dog lying in the sun"
]
new_encoder.train(extra_corpus)
new_bm25_model_path = "new_model.json"
new_encoder.dump(new_bm25_model_path)
复制代码
Java 示例:
import com.aliyun.dashtext.encoder.SparseVectorEncoder;
import java.io.*;
import java.util.*;
public class Main {
public static class Result {
public String doc;
public float score;
public Result(String doc, float score) {
this.doc = doc;
this.score = score;
}
@Override
public String toString() {
return String.format("Result(doc=%s, score=%f)", doc, score);
}
}
public static float calculateScore(Map<Long, Float> queryVector, Map<Long, Float> documentVector) {
float score = 0.0f;
for (Map.Entry<Long, Float> entry : queryVector.entrySet()) {
if (documentVector.containsKey(entry.getKey())) {
score += entry.getValue() * documentVector.get(entry.getKey());
}
}
return score;
}
public static void main(String[] args) throws IOException {
// 创建空SparseVectorEncoder(可以设置自定义Tokenizer)
SparseVectorEncoder encoder = new SparseVectorEncoder();
// step1: 准备语料以及Documents
List<String> corpusDocument = Arrays.asList(
"The quick brown fox rapidly and agilely leaps over the lazy dog that lies idly by the roadside.",
"Never jump over the lazy dog quickly",
"A fox is quick and jumps over dogs",
"The quick brown fox",
"Dogs are domestic animals",
"Some dog breeds are quick and jump high",
"Foxes are wild animals and often have a brown coat"
);
// step2: 训练BM25 Model
encoder.train(corpusDocument);
// step3: 调参优化BM25 Model
String query = "quick brown fox";
System.out.println("query: " + query);
float[] k1s = {1.0f, 1.5f};
float[] bs = {0.5f, 0.75f};
for (int i = 0; i < k1s.length; i++) {
float k1 = k1s[i];
float b = bs[i];
System.out.println("current k1: " + k1 + ", b: " + b);
encoder.setB(b);
encoder.setK1(k1);
Map<Long, Float> queryVector = encoder.encodeQueries(query);
List<Result> results = new ArrayList<>();
for (String doc : corpusDocument) {
Map<Long, Float> docVector = encoder.encodeDocuments(doc);
float score = calculateScore(queryVector, docVector);
results.add(new Result(doc, score));
}
results.sort((r1, r2) -> Float.compare(r2.score, r1.score));
for (Result result : results) {
System.out.println(result);
}
}
// step4: 选择最优参数并保存模型
encoder.setB(0.75f);
encoder.setK1(1.5f);
encoder.dump("./model.json");
// step5: 后续使用时可以加载模型
SparseVectorEncoder newEncoder = new SparseVectorEncoder();
newEncoder.load("./model.json");
// step6: 对模型进行finetune并保存
List<String> extraCorpus = Arrays.asList(
"The fast fox jumps over the lazy, chubby dog",
"A swift fox hops over a napping old dog",
"The quick fox leaps over the sleepy, plump dog",
"The agile fox jumps over the dozing, heavy-set dog",
"A speedy fox vaults over a lazy, old dog lying in the sun"
);
newEncoder.train(extraCorpus);
newEncoder.dump("./new_model.json");
}
}
复制代码
API 参考
DashText API 详情可参考:https://pypi.org/project/dashtext/
评论