写点什么

MindIE-LLM ATB 模型推理全流程解析

作者:AI布道Mr.Jin
  • 2025-06-26
    上海
  • 本文字数:6293 字

    阅读完需:约 21 分钟

最近,有很多小伙伴问我,如果他们想自己基于 MindIE 镜像中的文件适配新模型,可以怎么做?


为了实现这个目标,首先需要了解 MindIE-LLM 模型在推理过程中的代码调用流程,然后根据新模型的算法进行适配。

背景知识

MindIE-LLM 组件采用 ATB 算子构建模型。ATB 全称 Ascend transformer boost,是一款高效、可靠的加速库,基于华为 Ascend AI 处理器,专门为 Transformer 模型的训练和推理而设计。开发者可以使用 ATB 算子组图,实现大模型的整图高性能推理,详情可以参考官网链接:https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/acce/ascendtb/ascendtb_0001.html

代码入口

本文以 llama 模型为例,从入口脚本run_pa.py开始,分析模型路由、模型实例化(权重导入)和图构建推理的过程。


MindIE-LLM ATB 模型的推理入口文件在官网 MindIE 镜像的这个位置:/usr/local/Ascend/atb-models/examples/run_pa.py 。这个文件的核心代码如下:


pa_runner = PARunner(**input_dict)print_log(rank, logger.info, f'pa_runner: {pa_runner}')pa_runner.warm_up()
infer_params = { "inputs": infer_inputs, "batch_size": args.max_batch_size, "max_output_length": args.max_output_length, "ignore_eos": args.ignore_eos, "is_chat_model": args.is_chat_model}generate_texts, token_nums, _ = pa_runner.infer(**infer_params)
复制代码


pa_runner实例化的过程中包含了模型类的路由、权重导入和计算图构建,接下来我们逐个分析。

模型类路由

这部分的功能是根据用户传入的 config 参数获取模型类。



上图是模型类路由的代码调用流程。PARunnerinit函数会调用self.model = ModelRunner()进行模型类的获取。ModelRunner定义在model_runner.py文件中。ModelRunnerinit函数调用router_ins = get_model获取模型信息。我们来看一下get_model()函数:


...router_path = f"atb_llm.models.{model_type}.router_{model_type}"if model_type == "qwen2_moe" or model_type == "qwen3_moe":    model_type = model_type.replace('_', '')if model_type == "qwen2_audio":    model_type = model_type.replace('_', '')if model_type == "qwen2_vl":    model_type = model_type.replace('_', '')if model_type == "minicpm_qwen2_v2":    model_type = model_type.replace('_', '')router = importlib.import_module(router_path)router_cls = getattr(router, f"{model_type.capitalize()}Router")router_ins = router_cls(    model_name_or_path,    config_dict,    is_flash_causal_lm,    load_tokenizer,    max_position_embeddings,    revision,    tokenizer_path,    trust_remote_code,    enable_atb_torch,    enable_edge,    enable_refactor,    llm_config)return router_ins
复制代码


从上面代码的第 1 行可以看到,这个函数根据 config 文件中的model_type找到了 llama 模型路由的位置atb_llm\models\llama\router_llama.py,以及router_cls=LlamaRouter()


然后回到ModelRunnerinit函数中,运行了self.model_cls = router_ins.model_cls来获得模型类。LlamaRoutermodel_cls()函数定义在它的基类BaseRouter里面:


def get_model_cls(self):  ...    model_cls_name = f"{self.model_type_cap}ForCausalLM"    if self.enable_atb_torch:        model_cls_name += "ATB"    if self.is_flash_causal_lm:        model_cls_name = "Flash" + model_cls_name    if self.enable_refactor:        model_cls_name += "V2"    return getattr(module, model_cls_name)
复制代码


可以看到,这段代码根据model_cls_name找到了模型类FlashLlamaForCausalLMATB以及它的文件名flash_causal_llama_atb.py。需要注意的是,此时只是获取了模型类,还没有做实例化。


打开代码仓的同学应该发现了,router_llama.pyflash_causal_llama_atb.py都放在atb_models\atb_llm\models\llama目录下。所以,如果你想重新适配一个模型,那么也需要在atb_models\atb_llm\models目录下创建一个新模型对应的目录,并且实现这些文件。

模型实例化 &权重导入

PARunner获取到模型类之后,继续调用self.model.load_weights把权重加载到模型中(同时完成了模型实例化),代码调用流程如下:



load_weights函数的主要逻辑如下,包括模型的实例化和模型下发到 device:


self.model = self.model_cls(...)...self.model.to(weights.device)
复制代码


FlashLlamaForCausalLMATB的初始化函数中调用了self.model = LlamaModelATB()构建模型,我们继续看一下LlamaModelATB的初始化函数:


...is_parallel = config.vocab_size >= LLAMA_EMBEDDING_PARALLEL_THRESHOLDsuper().__init__(config, weights, model_prefix, lm_head_prefix, is_parallel, is_fa, backend)
self.layers = nn.ModuleList( [LlamaLayer(layer_idx, config, weights, model_prefix, self.is_fa, self.backend, speculate_enable) \ for layer_idx in range(config.num_hidden_layers)])
linear_info = LmHeadLinearInfo()linear_info.lm_head_name = lm_head_prefixself.norm = BaseRMSNorm(f"{model_prefix}.norm", config, weights, linear_info)
复制代码


self.layers又调用了class LlamaLayer定义每一层的结构,详情如下:


...# 模型结构self.self_attn = LlamaAttention(    config=config, weights=weights, prefix=f"{prefix}.self_attn", norm_prefix=f"{prefix}.input_layernorm", \    is_fa=self.is_fa, backend=backend, speculate_enable=self.speculate_enable)
self.mlp = BaseMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, norm_prefix=f"{prefix}.post_attention_layernorm", backend=backend)
self.input_layernorm = BaseRMSNorm( f"{prefix}.input_layernorm", config, weights, self.self_attn.linear_info)
self.post_attention_layernorm = BaseRMSNorm( f"{prefix}.post_attention_layernorm", config, weights, self.mlp.linear_info)
复制代码


可以看到,上面把 transformer 层中的 attention、mlp 和 norm 层都进行了定义,如果继续观察每一层的初始化函数,可以发现是调用了 pytorch 的 linear 算子接口或者nn.Parameter来加载权重,然后把线性层信息保存到self.linear_info变量,下一步进行图构建会用到这个变量。

计算图构建


ModelRunner.load_weights完成权重加载后,继续调用self.model.init_graph()进行 ATB 算子的调用和计算图构建。self.model对应的是FlashLlamaForCausalLMATB,其init_graph函数继承自基类FlashForCausalLMATB


def init_graph(self):    """Initialze weight, prefill graph and decode graph."""    # 获取权重键值对    self.weight = self.get_weights()    # 创建atb graph    self.prefill_graph = AtbGraph(f"{self.name}_prefill_graph")    self.build_graph(self.prefill_graph, is_prefill=True)    self.decode_graph = AtbGraph(f"{self.name}_decode_graph")    self.build_graph(self.decode_graph, is_prefill=False)
复制代码


可以看到,这个函数初始化了 2 个AtbGraph,分别对应首 token 计算图和增量计算图。AtbGraph继承自atb._GraphOperation,是 C++的 pybind 接口,目前这部分代码没有开源。初始化 ATB 图后,又调用了self.build_graphFlashLlamaForCausalLMATB.build_graph()定义如下:


def build_graph(self, graph, is_prefill):    # 设置输入输出    kv_cache_names = []    for i in range(self.config.num_hidden_layers):        kv_cache_names.extend([f"layer_{i}_k_cache", f"layer_{i}_v_cache"])    graph.add_input_output(        input=list(self.weight.keys()) + kv_cache_names + self.get_in_tensor_names(is_prefill),        output=self.get_out_tensor_names())
# 增加图节点 self.model.build_graph(graph, is_prefill) self.build_lm_head(graph, is_prefill)
# 构图 graph.execute_as_single = False graph.build()
复制代码


首先准备了输入输出,然后调用self.model.build_graph构图,对应的是LlamaModelATB.build_graph()


def build_graph(self, graph, is_prefill):    self.build_word_embedding_graph(graph)    self.build_positional_embedding_graph(graph)    for layer in self.layers:        layer.build_graph(graph, is_prefill)    self.norm.build_graph(graph, is_prefill)
复制代码


可以看到,代码逻辑是把每一层都 build 到graph里面去,我们继续打开LlamaLayer.build_graph()


def build_graph(self, graph, is_prefill):    self.layer_graph = AtbGraph(("prefill" if is_prefill else "decode") + f"_layer_{self.layer_id}_graph")    self.layer_graph.add_input_output(        input=self.weight_names + ["k_cache", "v_cache"] + self.get_in_tensor_names(is_prefill),        output=["layer_out"])    if self.is_reshape:        self.layer_graph.add_reshape("hidden_states", "hidden_states", self.reshape_parallel)    self.input_layernorm.build_graph(self.layer_graph, is_prefill)    self.self_attn.build_graph(self.layer_graph, is_prefill)    self.post_attention_layernorm.build_graph(self.layer_graph, is_prefill)    self.mlp.build_graph(self.layer_graph, is_prefill)    self.layer_graph.build()
graph.operations.append(self.layer_graph) graph.add_operation(self.layer_graph, self.weight_names + \ [f"layer_{self.layer_id}_k_cache", f"layer_{self.layer_id}_v_cache"] + self.get_in_tensor_names( is_prefill), ["hidden_states"])
复制代码


这段代码首先建立了一个子图self.layer_graph,然后把 norm 层、attention 层和 mlp 层都进行 build。我们以self_attn.build_graph为例继续打开:


def build_graph(self, graph, is_prefill):    atten_res_add = atb._BaseOperation(op_type="Elewise",         op_param=json.dumps({'elewiseType': 'ELEWISE_ADD'}),                                       op_name='atten_res_add')    setattr(graph, 'atten_res_add', atten_res_add)
self.build_qkv_graph(graph) self.build_rope_graph(graph) self.build_attention_graph(graph, is_prefill) self.build_dense_graph(graph, is_prefill)
graph.add_operation(graph.atten_res_add, ['hidden_states', 'dense_out'], ['hidden_states'])
复制代码


这里面又包含了 qkv 的计算、attention 计算和输出映射层的计算,我们看一下build_attention_graph是如何调用 ATB 算子的:


def build_attention_graph(self, graph, is_prefill):  ...    pa_attention_builder = CommonOpBuilderManager.get_builder(attention_param)    graph = pa_attention_builder.build(graph, attention_tensor_map)
复制代码


可以看到,这里通过CommonOpBuilderManager.get_builder获得了 pa_attention 算子的 builder。CommonOpBuilderManager是定义在common_op_builder_manager.py里面的类,它的功能是把 transformer 模型通用的算子进行管理,方便用户构建模型的时候调用。它的代码实现如下::


class CommonOpBuilderManager:    _common_op_builders = []
@classmethod def register(cls, common_op_builder_class): cls._common_op_builders.append(common_op_builder_class())
@classmethod def get_builder(cls, param: dict) -> BaseCommonOpBuilder | None: for common_op_builder in cls._common_op_builders: if common_op_builder.is_match(param): return common_op_builder print_log(ENV.rank, logger.debug, f"CommonOpBuilder not found for param: {param}") raise RuntimeError(f"CommonOpBuilder not found for param: {param}")
复制代码


注意到,它的get_builder函数可以根据传入的 param 返回对应的算子 builder。而且字典变量_common_op_builders里面的值是通过调用register进行更新的。大家可能有疑问,这个register函数是在哪里被调用的呢?实际上是在atb-models/atb_llm/common_op_builders下面的每类算子的__init__.py中执行的,比如atb_models\atb_llm\common_op_builders\attention


from atb_llm.common_op_builders.common_op_builder_manager import CommonOpBuilderManagerfrom atb_llm.common_op_builders.attention.atb_decoder_paged_attention_common_op_builder import \    ATBDecoderPagedAttentionCommonOpBuilderfrom atb_llm.common_op_builders.attention.atb_encoder_paged_attention_common_op_builder import \    ATBEncoderPagedAttentionCommonOpBuilderfrom atb_llm.common_op_builders.attention.atb_flash_attention_common_op_builder import \    ATBFlashAttentionCommonOpBuilder
CommonOpBuilderManager.register(ATBDecoderPagedAttentionCommonOpBuilder)CommonOpBuilderManager.register(ATBEncoderPagedAttentionCommonOpBuilder)CommonOpBuilderManager.register(ATBFlashAttentionCommonOpBuilder)
复制代码


对于 prefill_graph,我们结合build_attention_graph()函数中的attention_param


attention_param = {    "op_name": "attention",    "category": CommonOpBuilderType.ATTENTION,    "is_prefill": is_prefill,    "attn_type": AttnType.FLASH_ATTENTION if self.is_fa else AttnType.PAGED_ATTENTION,    "head_size": self.head_size,    "atb_reshape_and_cache_param": {},    "operation_backend": OperationBackend.ATB,    "atb_attention_param": self._get_atb_attention_param(is_prefill)}
复制代码


以及ATBEncoderPagedAttentionCommonOpBuilderis_match()函数,可知获取的 op_builder 是ATBEncoderPagedAttentionCommonOpBuilder类,它的build()函数逻辑如下:


def build(self, graph: atb._GraphOperation, tensor_map: dict) -> atb._GraphOperation:    ...    # self attention    attention_op = atb._BaseOperation(        op_type="SelfAttention",        op_param=json.dumps(self.param.atb_attention_param),        op_name=f"{self.param.op_name}_SelfAttention"    )    graph.operations.append(attention_op)    ...    return graph
复制代码


可以看到,这里通过atb._BaseOperation接口调用了 atb 算子。


其他算子的调用逻辑也同理,大家可以自己查看一遍。

总结

这篇文章主要分析了 ATB 模型推理的代码调用栈,同时给出了新模型适配涉及的代码目录。ATB 模型的适配代码目录在/usr/local/Ascend/atb-models/atb_llm/models,以 llama 模型为例,/usr/local/Ascend/atb-models/atb_llm/models/llama下面包含模型路由脚本router_llama.py以及模型类的定义脚本flash_causal_llama_atb.pymodeling_llama_atb.py。如果需要适配新的模型,需要在/models下面创建新的目录并实现上述脚本内容。


MindIE-LLM 提供了构建 transformer 模型的通用算子,统一放在/usr/local/Ascend/atb-models/atb_llm/common_op_builders目录下面,每个算子都通过_libatb_torch._BaseOperation的方式调用 ATB 算子。

用户头像

还未添加个人签名 2020-11-13 加入

还未添加个人简介

评论

发布
暂无评论
MindIE-LLM ATB模型推理全流程解析_AI布道Mr.Jin_InfoQ写作社区