写点什么

解析 WeNet 云端推理部署代码

  • 2021 年 12 月 14 日
  • 本文字数:6334 字

    阅读完需:约 21 分钟

摘要:WeNet 是一款开源端到端 ASR 工具包,它与 ESPnet 等开源语音项目相比,最大的优势在于提供了从训练到部署的一整套工具链,使 ASR 服务的工业落地更加简单。

 

本文分享自华为云社区《WeNet云端推理部署代码解析》,作者:xiaoye0829 。

 

WeNet 是一款开源端到端 ASR 工具包,它与 ESPnet 等开源语音项目相比,最大的优势在于提供了从训练到部署的一整套工具链,使 ASR 服务的工业落地更加简单。如图 1 所示,WeNet 工具包完全依赖于 PyTorch 生态:使用 TorchScript 进行模型开发,使用 Torchaudio 进行动态特征提取,使用 DistributedDataParallel 进行分布式训练,使用 torch JIT(Just In Time)进行模型导出,使用 LibTorch 作为生产环境运行时。本系列将对 WeNet 云端推理部署代码进行解析。


图 1:WeNet 系统设计[1]

1. 代码结构


WeNet 云端推理和部署代码位于 wenet/runtime/server/x86 路径下,编程语言为 C++,其结构如下所示:



​其中:

  • 语音文件读入与特征提取相关代码位于 frontend 文件夹下;

  • 端到端模型导入、端点检测与语音解码识别相关代码位于 decoder 文件夹下,WeNet 支持 CTC prefix beam search 和融合了 WFST 的 CTC beam search 这两种解码算法,后者的实现大量借鉴了 Kaldi,相关代码放在 kaldi 文件夹下;

  • 在服务化方面,WeNet 分别实现了基于 WebSocket 和基于 gRPC 的两套服务端与客户端,基于 WebSocket 的实现位于 websocket 文件夹下,基于 gRPC 的实现位于 grpc 文件夹下,两种实现的入口 main 函数代码都位于 bin 文件夹下。

  • 日志、计时、字符串处理等辅助代码位于 utils 文件夹下。


WeNet 提供了 CMakeLists.txt 和 Dockerfile,使得用户能方便地进行项目编译和镜像构建。

2. 前端:frontend 文件夹

1)语音文件读入


WeNet 只支持 44 字节 header 的 wav 格式音频数据,wav header 定义在 WavHeader 结构体中,包括音频格式、声道数、采样率等音频元信息。WavReader 类用于语音文件读入,调用 fopen 打开语音文件后,WavReader 先读入 WavHeader 大小的数据(也就是 44 字节),再根据 WavHeader 中的元信息确定待读入音频数据的大小,最后调用 fread 把音频数据读入 buffer,并通过 static_cast 把数据转化为 float 类型。


struct WavHeader {  char riff[4];  // "riff"  unsigned int size;  char wav[4];  // "WAVE"  char fmt[4];  // "fmt "  unsigned int fmt_size;  uint16_t format;  uint16_t channels;  unsigned int sample_rate;  unsigned int bytes_per_second;  uint16_t block_size;  uint16_t bit;  char data[4];  // "data"  unsigned int data_size;};
复制代码


​这里存在的一个风险是,如果 WavHeader 中存放的元信息有误,则会影响到语音数据的正确读入。

2)特征提取


WeNet 使用的特征是 fbank,通过 FeaturePipelineConfig 结构体进行特征设置。默认帧长为 25ms,帧移为 10ms,采样率和 fbank 维数则由用户输入。


用于特征提取的类是 FeaturePipeline。为了同时支持流式与非流式语音识别,FeaturePipeline 类中设置了 input_finished_属性来标志输入是否结束,并通过 set_input_finished()成员函数来对 input_finished_属性进行操作。


提取出来的 fbank 特征放在 feature_queue_中,feature_queue_的类型是 BlockingQueue<std::vector<float>>。BlockingQueue 类是 WeNet 实现的一个阻塞队列,初始化的时候需要提供队列的容量(capacity),通过 Push()函数向队列中增加特征,通过 Pop()函数从队列中读取特征:

  • 当 feature_queue_中的 feature 数量超过 capacity,则 Push 线程被挂起,等待 feature_queue_.Pop()释放出空间。

  • 当 feature_queue_为空,则 Pop 线程被挂起,等待 feature_queue_.Push()。

线程的挂起和恢复是通过 C++标准库中的线程同步原语 std::mutex、std::condition_variable 等实现。

线程同步还用在 AcceptWaveform 和 ReadOne 两个成员函数中,AcceptWaveform 把语音数据提取得到的 fbank 特征放到 feature_queue_中,ReadOne 成员函数则把特征从 feature_queue_中读出,是经典的生产者消费者模式。

3. 解码器:decoder 文件夹


1)TorchAsrModel


通过 torch::jit::load 对存在磁盘上的模型进行反序列化,得到一个 ScriptModule 对象。


torch::jit::script::Module model = torch::jit::load(model_path);
复制代码


2)SearchInterface


WeNet 推理支持的解码方式都继承自基类 SearchInterface,如果要新增解码算法,则需继承 SearchInterface 类,并提供该类中所有纯虚函数的实现,包括:


// 解码算法的具体实现virtual void Search(const torch::Tensor& logp) = 0;// 重置解码过程virtual void Reset() = 0;// 结束解码过程virtual void FinalizeSearch() = 0;// 解码算法类型,返回一个枚举常量SearchTypevirtual SearchType Type() const = 0;// 返回解码输入virtual const std::vector<std::vector<int>>& Inputs() const = 0;// 返回解码输出virtual const std::vector<std::vector<int>>& Outputs() const = 0;// 返回解码输出对应的似然值virtual const std::vector<float>& Likelihood() const = 0;// 返回解码输出对应的次数virtual const std::vector<std::vector<int>>& Times() const = 0;
复制代码


​目前 WeNet 只提供了 SearchInterface 的两种子类实现,也即两种解码算法,分别定义在 CtcPrefixBeamSearch 和 CtcWfstBeamSearch 两个类中。

3)CtcEndpoint


WeNet 支持语音端点检测,提供了一种基于规则的实现方式,用户可以通过 CtcEndpointConfig 结构体和 CtcEndpointRule 结构体进行规则配置。WeNet 默认的规则有三条:

  • 检测到了 5s 的静音,则认为检测到端点;

  • 解码出了任意时长的语音后,检测到了 1s 的静音,则认为检测到端点;

  • 解码出了 20s 的语音,则认为检测到端点。

一旦检测到端点,则结束解码。另外,WeNet 把解码得到的空白符(blank)视作静音。

4)TorchAsrDecoder


WeNet 提供的解码器定义在 TorchAsrDecoder 类中。如图 3 所示,WeNet 支持双向解码,即叠加从左往右解码和从右往左解码的结果。在 CTC beam search 之后,用户还可以选择进行 attention 重打分。


图 2:WeNet 解码计算流程[2]

可以通过 DecodeOptions 结构体进行解码参数配置,包括如下参数:


struct DecodeOptions {  int chunk_size = 16;  int num_left_chunks = -1;  float ctc_weight = 0.0;  float rescoring_weight = 1.0;  float reverse_weight = 0.0;  CtcEndpointConfig ctc_endpoint_config;  CtcPrefixBeamSearchOptions ctc_prefix_search_opts;  CtcWfstBeamSearchOptions ctc_wfst_search_opts;};
复制代码


​其中,ctc_weight 表示 CTC 解码权重,rescoring_weight 表示重打分权重,reverse_weight 表示从右往左解码权重。最终解码打分的计算方式为:


final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score;rescoring_score = left_to_right_score * (1 - reverse_weight) +right_to_left_score * reverse_weight
复制代码


​TorchAsrDecoder 对外提供的解码接口是 Decode(),重打分接口是 Rescoring()。Decode()返回的是枚举类型 DecodeState,包括三个枚举常量:kEndBatch,kEndpoint 和 kEndFeats,分别表示当前批数据解码结束、检测到端点、所有特征解码结束。


为了支持长语音识别,WeNet 还提供了连续解码接口 ResetContinuousDecoding(),它与解码器重置接口 Reset()的区别在于:连续解码接口会记录全局已经解码的语音帧数,并保留当前 feature_pipeline_的状态。


由于流式 ASR 服务需要在客户端和服务端之间进行双向的流式数据传输,WeNet 实现了两种支持双向流式通信的服务化接口,分别基于 WebSocket 和 gRPC。

4. 基于 WebSocket

1)WebSocket 简介


WebSocket 是基于 TCP 的一种新的网络协议,与 HTTP 协议不同,WebSocket 允许服务器主动发送信息给客户端。 在连接建立后,客户端和服务端可以连续互相发送数据,而无需在每次发送数据时重新发起连接请求。因此大大减小了网络带宽的资源消耗 ,在性能上更有优势。


WebSocket 支持文本和二进制两种格式的数据传输 。

2)WeNet 的 WebSocket 接口


WeNet 使用了 boost 库的 WebSocket 实现,定义了 WebSocketClient(客户端)和 WebSocketServer(服务端)两个类。


在流式 ASR 过程中,WebSocketClient 给 WebSocketServer 发送数据可以分为三个步骤:1)发送开始信号与解码配置;2)发送二进制语音数据:pcm 字节流;3)发送停止信号。从 WebSocketClient::SendStartSignal()和 WebSocketClient::SendEndSignal()可以看到,开始信号、解码配置和停止信号都是包装在 json 字符串中,通过 WebSocket 文本格式传输。pcm 字节流则通过 WebSocket 二进制格式进行传输。


void WebSocketClient::SendStartSignal() {  // TODO(Binbin Zhang): Add sample rate and other setting surpport  json::value start_tag = {{"signal", "start"},                           {"nbest", nbest_},                           {"continuous_decoding", continuous_decoding_}};  std::string start_message = json::serialize(start_tag);  this->SendTextData(start_message);}
void WebSocketClient::SendEndSignal() { json::value end_tag = {{"signal", "end"}}; std::string end_message = json::serialize(end_tag); this->SendTextData(end_message);}
复制代码


​WebSocketServer 在收到数据后,需要先判断收到的数据是文本还是二进制格式:如果是文本数据,则进行 json 解析,并根据解析结果进行解码配置、启动或停止,处理逻辑定义在 ConnectionHandler::OnText()函数中。如果是二进制数据,则进行语音识别,处理逻辑定义在 ConnectionHandler::OnSpeechData()中。

3)缺点


WebSocket 需要开发者在 WebSocketClient 和 WebSocketServer 写好对应的消息构造和解析代码,容易出错。另外,从以上代码来看,服务需要借助 json 格式来序列化和反序列化数据,效率没有 protobuf 格式高。


对于这些缺点,gRPC 框架提供了更好的解决方法。

5. 基于 gRPC

1)gRPC 简介


gRPC 是谷歌推出的开源 RPC 框架,使用 HTTP2 作为网络传输协议,并使用 protobuf 作为数据交换格式,有更高的数据传输效率。在 gRPC 框架下,开发者只需通过一个.proto 文件定义好 RPC 服务(service)与消息(message),便可通过 gRPC 提供的代码生成工具(protoc compiler)自动生成消息构造和解析代码,使开发者能更好地聚焦于接口设计本身。


进行 RPC 调用时,gRPC Stub(客户端)向 gRPC Server(服务端)发送.proto 文件中定义的 Request 消息,gRPC Server 在处理完请求之后,通过.proto 文件中定义的 Response 消息将结果返回给 gRPC Stub。


gRPC 具有跨语言特性,支持不同语言写的微服务进行互动,比如说服务端用 C++实现,客户端用 Ruby 实现。protoc compiler 支持 12 种语言的代码生成。


图 3:gRPC Server 和 gRPC Stub 交互[3]

2)WeNet 的 proto 文件


WeNet 定义的服务为 ASR,包含一个 Recognize 方法,该方法的输入(Request)、输出(Response)都是流式数据(stream)。在使用 protoc compiler 编译 proto 文件后,会得到 4 个文件:wenet.grpc.pb.h,wenet.grpc.pb.cc,wenet.pb.h,wenet.pb.cc。其中,wenet.pb.h/cc 中存储了 protobuf 数据格式的定义,wenet.grpc.pb.h 中存储了 gRPC 服务端/客户端的定义。通过在代码中包括 wenet.pb.h 和 wenet.grpc.pb.h 两个头文件,开发者可以直接使用 Request 消息和 Response 消息类,访问其字段。


service ASR {  rpc Recognize (stream Request) returns (stream Response) {}}
message Request {
message DecodeConfig { int32 nbest_config = 1; bool continuous_decoding_config = 2; }
oneof RequestPayload { DecodeConfig decode_config = 1; bytes audio_data = 2; }}
message Response {
message OneBest { string sentence = 1; repeated OnePiece wordpieces = 2; }
message OnePiece { string word = 1; int32 start = 2; int32 end = 3; }
enum Status { ok = 0; failed = 1; }
enum Type { server_ready = 0; partial_result = 1; final_result = 2; speech_end = 3; }
Status status = 1; Type type = 2; repeated OneBest nbest = 3;}
复制代码


3)WeNet 的 gRPC 实现


WeNet gRPC 服务端定义了 GrpcServer 类,该类继承自 wenet.grpc.pb.h 中的纯虚基类 ASR::Service。


语音识别的入口函数是 GrpcServer::Recognize,该函数初始化一个 GRPCConnectionHandler 实例来进行语音识别,并通过 ServerReaderWriter 类的 stream 对象来传递输入输出。


Status GrpcServer::Recognize(ServerContext* context,                             ServerReaderWriter<Response, Request>* stream) {  LOG(INFO) << "Get Recognize request" << std::endl;  auto request = std::make_shared<Request>();  auto response = std::make_shared<Response>();  GrpcConnectionHandler handler(stream, request, response, feature_config_,                                decode_config_, symbol_table_, model_, fst_);  std::thread t(std::move(handler));  t.join();  return Status::OK;}
复制代码


​WeNet gRPC 客户端定义了 GrpcClient 类。客户端在建立与服务端的连接时需实例化 ASR::Stub,并通过 ClientReaderWriter 类的 stream 对象,实现双向流式通信。


void GrpcClient::Connect() {  channel_ = grpc::CreateChannel(host_ + ":" + std::to_string(port_),                                 grpc::InsecureChannelCredentials());  stub_ = ASR::NewStub(channel_);  context_ = std::make_shared<ClientContext>();  stream_ = stub_->Recognize(context_.get());  request_ = std::make_shared<Request>();  response_ = std::make_shared<Response>();  request_->mutable_decode_config()->set_nbest_config(nbest_);  request_->mutable_decode_config()->set_continuous_decoding_config(      continuous_decoding_);  stream_->Write(*request_);}
复制代码


​grpc_client_main.cc 中,客户端分段传输语音数据,每 0.5s 进行一次传输,即对于一个采样率为 8k 的语音文件来说,每次传 4000 帧数据。为了减小传输数据的大小,提升数据传输速度,先在客户端将 float 类型转为 int16_t,服务端在接受到数据后,再将 int16_t 转为 float。c++中 float 为 32 位。


int main(int argc, char *argv[]) {  ...  // Send data every 0.5 second  const float interval = 0.5;  const int sample_interval = interval * sample_rate;  for (int start = 0; start < num_sample; start += sample_interval) {    if (client.done()) {      break;    }    int end = std::min(start + sample_interval, num_sample);    // Convert to short    std::vector<int16_t> data;    data.reserve(end - start);    for (int j = start; j < end; j++) {      data.push_back(static_cast<int16_t>(pcm_data[j]));    }    // Send PCM data    client.SendBinaryData(data.data(), data.size() * sizeof(int16_t));    ...}
复制代码


总结


本文主要对 WeNet 云端部署代码进行解析,介绍了 WeNet 基于 WebSocket 和基于 gRPC 的两种服务化接口。


WeNet 代码结构清晰,简洁易用,为语音识别提供了从训练到部署的一套端到端解决方案,大大促进了工业落地效率,是非常值得借鉴学习的语音开源项目。

参考


[1] https://grpc.io/docs/what-is-grpc/introduction/

[2]WeNet: Production First and Production Ready End-to-End Speech Recognition Toolkit

[3]WeNet源码

[4]WeNet: Production First and Production Ready End-to-End Speech Recognition Toolkit

[5] U2++: Unified Two-pass Bidirectional End-to-end Model for Speech Recognition


点击关注,第一时间了解华为云新鲜技术~

发布于: 3 小时前阅读数: 5
用户头像

提供全面深入的云计算技术干货 2020.07.14 加入

华为云开发者社区,提供全面深入的云计算前景分析、丰富的技术干货、程序样例,分享华为云前沿资讯动态,方便开发者快速成长与发展,欢迎提问、互动,多方位了解云计算! 传送门:https://bbs.huaweicloud.com/

评论

发布
暂无评论
解析WeNet云端推理部署代码