最近,有很多小伙伴问我,如果他们想自己基于 MindIE 镜像中的文件适配新模型,可以怎么做?
为了实现这个目标,首先需要了解 MindIE-LLM 模型在推理过程中的代码调用流程,然后根据新模型的算法进行适配。
背景知识
MindIE-LLM 组件采用 ATB 算子构建模型。ATB 全称 Ascend transformer boost,是一款高效、可靠的加速库,基于华为 Ascend AI 处理器,专门为 Transformer 模型的训练和推理而设计。开发者可以使用 ATB 算子组图,实现大模型的整图高性能推理,详情可以参考官网链接:https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/acce/ascendtb/ascendtb_0001.html 。
代码入口
本文以 llama 模型为例,从入口脚本run_pa.py
开始,分析模型路由、模型实例化(权重导入)和图构建推理的过程。
MindIE-LLM ATB 模型的推理入口文件在官网 MindIE 镜像的这个位置:/usr/local/Ascend/atb-models/examples/run_pa.py
。这个文件的核心代码如下:
pa_runner = PARunner(**input_dict)
print_log(rank, logger.info, f'pa_runner: {pa_runner}')
pa_runner.warm_up()
infer_params = {
"inputs": infer_inputs,
"batch_size": args.max_batch_size,
"max_output_length": args.max_output_length,
"ignore_eos": args.ignore_eos,
"is_chat_model": args.is_chat_model
}
generate_texts, token_nums, _ = pa_runner.infer(**infer_params)
复制代码
pa_runner
实例化的过程中包含了模型类的路由、权重导入和计算图构建,接下来我们逐个分析。
模型类路由
这部分的功能是根据用户传入的 config 参数获取模型类。
上图是模型类路由的代码调用流程。PARunner
的init
函数会调用self.model = ModelRunner()
进行模型类的获取。ModelRunner
定义在model_runner.py
文件中。ModelRunner
的init
函数调用router_ins = get_model
获取模型信息。我们来看一下get_model()
函数:
...
router_path = f"atb_llm.models.{model_type}.router_{model_type}"
if model_type == "qwen2_moe" or model_type == "qwen3_moe":
model_type = model_type.replace('_', '')
if model_type == "qwen2_audio":
model_type = model_type.replace('_', '')
if model_type == "qwen2_vl":
model_type = model_type.replace('_', '')
if model_type == "minicpm_qwen2_v2":
model_type = model_type.replace('_', '')
router = importlib.import_module(router_path)
router_cls = getattr(router, f"{model_type.capitalize()}Router")
router_ins = router_cls(
model_name_or_path,
config_dict,
is_flash_causal_lm,
load_tokenizer,
max_position_embeddings,
revision,
tokenizer_path,
trust_remote_code,
enable_atb_torch,
enable_edge,
enable_refactor,
llm_config)
return router_ins
复制代码
从上面代码的第 1 行可以看到,这个函数根据 config 文件中的model_type
找到了 llama 模型路由的位置atb_llm\models\llama\router_llama.py
,以及router_cls=LlamaRouter()
。
然后回到ModelRunner
的init
函数中,运行了self.model_cls = router_ins.model_cls
来获得模型类。LlamaRouter
的model_cls()
函数定义在它的基类BaseRouter
里面:
def get_model_cls(self):
...
model_cls_name = f"{self.model_type_cap}ForCausalLM"
if self.enable_atb_torch:
model_cls_name += "ATB"
if self.is_flash_causal_lm:
model_cls_name = "Flash" + model_cls_name
if self.enable_refactor:
model_cls_name += "V2"
return getattr(module, model_cls_name)
复制代码
可以看到,这段代码根据model_cls_name
找到了模型类FlashLlamaForCausalLMATB
以及它的文件名flash_causal_llama_atb.py
。需要注意的是,此时只是获取了模型类,还没有做实例化。
打开代码仓的同学应该发现了,router_llama.py
和flash_causal_llama_atb.py
都放在atb_models\atb_llm\models\llama
目录下。所以,如果你想重新适配一个模型,那么也需要在atb_models\atb_llm\models
目录下创建一个新模型对应的目录,并且实现这些文件。
模型实例化 &权重导入
PARunner
获取到模型类之后,继续调用self.model.load_weights
把权重加载到模型中(同时完成了模型实例化),代码调用流程如下:
load_weights
函数的主要逻辑如下,包括模型的实例化和模型下发到 device:
self.model = self.model_cls(...)
...
self.model.to(weights.device)
复制代码
FlashLlamaForCausalLMATB
的初始化函数中调用了self.model = LlamaModelATB()
构建模型,我们继续看一下LlamaModelATB
的初始化函数:
...
is_parallel = config.vocab_size >= LLAMA_EMBEDDING_PARALLEL_THRESHOLD
super().__init__(config, weights, model_prefix, lm_head_prefix, is_parallel, is_fa, backend)
self.layers = nn.ModuleList(
[LlamaLayer(layer_idx, config, weights, model_prefix, self.is_fa, self.backend, speculate_enable) \
for layer_idx in range(config.num_hidden_layers)])
linear_info = LmHeadLinearInfo()
linear_info.lm_head_name = lm_head_prefix
self.norm = BaseRMSNorm(f"{model_prefix}.norm", config, weights, linear_info)
复制代码
self.layers
又调用了class LlamaLayer
定义每一层的结构,详情如下:
...
# 模型结构
self.self_attn = LlamaAttention(
config=config, weights=weights, prefix=f"{prefix}.self_attn", norm_prefix=f"{prefix}.input_layernorm", \
is_fa=self.is_fa, backend=backend, speculate_enable=self.speculate_enable)
self.mlp = BaseMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights,
norm_prefix=f"{prefix}.post_attention_layernorm", backend=backend)
self.input_layernorm = BaseRMSNorm(
f"{prefix}.input_layernorm", config, weights, self.self_attn.linear_info)
self.post_attention_layernorm = BaseRMSNorm(
f"{prefix}.post_attention_layernorm", config, weights, self.mlp.linear_info)
复制代码
可以看到,上面把 transformer 层中的 attention、mlp 和 norm 层都进行了定义,如果继续观察每一层的初始化函数,可以发现是调用了 pytorch 的 linear 算子接口或者nn.Parameter
来加载权重,然后把线性层信息保存到self.linear_info
变量,下一步进行图构建会用到这个变量。
计算图构建
ModelRunner.load_weights
完成权重加载后,继续调用self.model.init_graph()
进行 ATB 算子的调用和计算图构建。self.model
对应的是FlashLlamaForCausalLMATB
,其init_graph
函数继承自基类FlashForCausalLMATB
:
def init_graph(self):
"""Initialze weight, prefill graph and decode graph."""
# 获取权重键值对
self.weight = self.get_weights()
# 创建atb graph
self.prefill_graph = AtbGraph(f"{self.name}_prefill_graph")
self.build_graph(self.prefill_graph, is_prefill=True)
self.decode_graph = AtbGraph(f"{self.name}_decode_graph")
self.build_graph(self.decode_graph, is_prefill=False)
复制代码
可以看到,这个函数初始化了 2 个AtbGraph
,分别对应首 token 计算图和增量计算图。AtbGraph
继承自atb._GraphOperation
,是 C++的 pybind 接口,目前这部分代码没有开源。初始化 ATB 图后,又调用了self.build_graph
,FlashLlamaForCausalLMATB.build_graph()
定义如下:
def build_graph(self, graph, is_prefill):
# 设置输入输出
kv_cache_names = []
for i in range(self.config.num_hidden_layers):
kv_cache_names.extend([f"layer_{i}_k_cache", f"layer_{i}_v_cache"])
graph.add_input_output(
input=list(self.weight.keys()) + kv_cache_names + self.get_in_tensor_names(is_prefill),
output=self.get_out_tensor_names())
# 增加图节点
self.model.build_graph(graph, is_prefill)
self.build_lm_head(graph, is_prefill)
# 构图
graph.execute_as_single = False
graph.build()
复制代码
首先准备了输入输出,然后调用self.model.build_graph
构图,对应的是LlamaModelATB.build_graph()
:
def build_graph(self, graph, is_prefill):
self.build_word_embedding_graph(graph)
self.build_positional_embedding_graph(graph)
for layer in self.layers:
layer.build_graph(graph, is_prefill)
self.norm.build_graph(graph, is_prefill)
复制代码
可以看到,代码逻辑是把每一层都 build 到graph
里面去,我们继续打开LlamaLayer.build_graph()
:
def build_graph(self, graph, is_prefill):
self.layer_graph = AtbGraph(("prefill" if is_prefill else "decode") + f"_layer_{self.layer_id}_graph")
self.layer_graph.add_input_output(
input=self.weight_names + ["k_cache", "v_cache"] + self.get_in_tensor_names(is_prefill),
output=["layer_out"])
if self.is_reshape:
self.layer_graph.add_reshape("hidden_states", "hidden_states", self.reshape_parallel)
self.input_layernorm.build_graph(self.layer_graph, is_prefill)
self.self_attn.build_graph(self.layer_graph, is_prefill)
self.post_attention_layernorm.build_graph(self.layer_graph, is_prefill)
self.mlp.build_graph(self.layer_graph, is_prefill)
self.layer_graph.build()
graph.operations.append(self.layer_graph)
graph.add_operation(self.layer_graph, self.weight_names + \
[f"layer_{self.layer_id}_k_cache", f"layer_{self.layer_id}_v_cache"] + self.get_in_tensor_names(
is_prefill), ["hidden_states"])
复制代码
这段代码首先建立了一个子图self.layer_graph
,然后把 norm 层、attention 层和 mlp 层都进行 build。我们以self_attn.build_graph
为例继续打开:
def build_graph(self, graph, is_prefill):
atten_res_add = atb._BaseOperation(op_type="Elewise", op_param=json.dumps({'elewiseType': 'ELEWISE_ADD'}),
op_name='atten_res_add')
setattr(graph, 'atten_res_add', atten_res_add)
self.build_qkv_graph(graph)
self.build_rope_graph(graph)
self.build_attention_graph(graph, is_prefill)
self.build_dense_graph(graph, is_prefill)
graph.add_operation(graph.atten_res_add, ['hidden_states', 'dense_out'], ['hidden_states'])
复制代码
这里面又包含了 qkv 的计算、attention 计算和输出映射层的计算,我们看一下build_attention_graph
是如何调用 ATB 算子的:
def build_attention_graph(self, graph, is_prefill):
...
pa_attention_builder = CommonOpBuilderManager.get_builder(attention_param)
graph = pa_attention_builder.build(graph, attention_tensor_map)
复制代码
可以看到,这里通过CommonOpBuilderManager.get_builder
获得了 pa_attention 算子的 builder。CommonOpBuilderManager
是定义在common_op_builder_manager.py
里面的类,它的功能是把 transformer 模型通用的算子进行管理,方便用户构建模型的时候调用。它的代码实现如下::
class CommonOpBuilderManager:
_common_op_builders = []
@classmethod
def register(cls, common_op_builder_class):
cls._common_op_builders.append(common_op_builder_class())
@classmethod
def get_builder(cls, param: dict) -> BaseCommonOpBuilder | None:
for common_op_builder in cls._common_op_builders:
if common_op_builder.is_match(param):
return common_op_builder
print_log(ENV.rank, logger.debug, f"CommonOpBuilder not found for param: {param}")
raise RuntimeError(f"CommonOpBuilder not found for param: {param}")
复制代码
注意到,它的get_builder
函数可以根据传入的 param 返回对应的算子 builder。而且字典变量_common_op_builders
里面的值是通过调用register
进行更新的。大家可能有疑问,这个register
函数是在哪里被调用的呢?实际上是在atb-models/atb_llm/common_op_builders
下面的每类算子的__init__.py
中执行的,比如atb_models\atb_llm\common_op_builders\attention
:
from atb_llm.common_op_builders.common_op_builder_manager import CommonOpBuilderManager
from atb_llm.common_op_builders.attention.atb_decoder_paged_attention_common_op_builder import \
ATBDecoderPagedAttentionCommonOpBuilder
from atb_llm.common_op_builders.attention.atb_encoder_paged_attention_common_op_builder import \
ATBEncoderPagedAttentionCommonOpBuilder
from atb_llm.common_op_builders.attention.atb_flash_attention_common_op_builder import \
ATBFlashAttentionCommonOpBuilder
CommonOpBuilderManager.register(ATBDecoderPagedAttentionCommonOpBuilder)
CommonOpBuilderManager.register(ATBEncoderPagedAttentionCommonOpBuilder)
CommonOpBuilderManager.register(ATBFlashAttentionCommonOpBuilder)
复制代码
对于 prefill_graph,我们结合build_attention_graph()
函数中的attention_param
:
attention_param = {
"op_name": "attention",
"category": CommonOpBuilderType.ATTENTION,
"is_prefill": is_prefill,
"attn_type": AttnType.FLASH_ATTENTION if self.is_fa else AttnType.PAGED_ATTENTION,
"head_size": self.head_size,
"atb_reshape_and_cache_param": {},
"operation_backend": OperationBackend.ATB,
"atb_attention_param": self._get_atb_attention_param(is_prefill)
}
复制代码
以及ATBEncoderPagedAttentionCommonOpBuilder
的is_match()
函数,可知获取的 op_builder 是ATBEncoderPagedAttentionCommonOpBuilder
类,它的build()
函数逻辑如下:
def build(self, graph: atb._GraphOperation, tensor_map: dict) -> atb._GraphOperation:
...
# self attention
attention_op = atb._BaseOperation(
op_type="SelfAttention",
op_param=json.dumps(self.param.atb_attention_param),
op_name=f"{self.param.op_name}_SelfAttention"
)
graph.operations.append(attention_op)
...
return graph
复制代码
可以看到,这里通过atb._BaseOperation
接口调用了 atb 算子。
其他算子的调用逻辑也同理,大家可以自己查看一遍。
总结
这篇文章主要分析了 ATB 模型推理的代码调用栈,同时给出了新模型适配涉及的代码目录。ATB 模型的适配代码目录在/usr/local/Ascend/atb-models/atb_llm/models
,以 llama 模型为例,/usr/local/Ascend/atb-models/atb_llm/models/llama
下面包含模型路由脚本router_llama.py
以及模型类的定义脚本flash_causal_llama_atb.py
和modeling_llama_atb.py
。如果需要适配新的模型,需要在/models
下面创建新的目录并实现上述脚本内容。
MindIE-LLM 提供了构建 transformer 模型的通用算子,统一放在/usr/local/Ascend/atb-models/atb_llm/common_op_builders
目录下面,每个算子都通过_libatb_torch._BaseOperation
的方式调用 ATB 算子。
评论