
BEVFormer Open-Source Algorithm Line-by-Line Analysis (Part 2): Decoder and Det

  • 2024-09-04
    Guangdong
  • Word count: 10,756 characters

    Estimated reading time: about 35 minutes


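Preface:

There is plenty of material for understanding the overall BEVFormer framework, but detailed walkthroughs of the code itself are scarce. This series therefore aims to give readers a more intuitive understanding of the algorithm by following the implementation details and the tensor shape changes step by step.


In this series we go through the public BEVFormer code (the open-source implementation) line by line, to understand how BEVFormer works through its code, master the algorithmic details, and help readers build perception algorithms on top of this framework. At the end of the series, for Horizon Robotics users, I will also point out the changes that Horizon's reference algorithm makes to the open-source version and the reasons behind them, as a reference for deploying the algorithm.


The public code is well packaged and builds models through a registry, so the call relationships between modules are clearly reflected in the config files under configs/bevformer. We use bevformer_tiny.py as the example. The Encoder part has already been published, see "BEVFormer Open-Source Algorithm Line-by-Line Analysis (Part 1): Encoder"; this article focuses on BEVFormer's Decoder and Det (detection head) parts.


The analysis is given mainly as comments in the code.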

1 PerceptionTransformer:

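Function:


  • Passes the bev_embed output by the encoder into the decoder;

  • Splits the query_embedding defined in BEVFormer along the channel dimension into query_pos and query, which have the same number of channels, and passes both into the decoder;

  • Feeds query_pos through the linear layer self.reference_points to generate reference_points, which are passed into the decoder. In the decoder's CustomMSDeformableAttention these reference_points serve as the base sampling points for fusing bev_embed, playing a role similar to region proposals in two-stage object detectors;

  • Returns inter_states and inter_references, which are fed to cls_branches and reg_branches to obtain object classes and bboxes.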


Walkthrough:


```python
# See "BEVFormer Open-Source Algorithm Line-by-Line Analysis (Part 1): Encoder"; this produces bev_embed.
# In the decoder, CustomMSDeformableAttention fuses bev_embed with the object queries.
bev_embed = self.get_bev_features(
    mlvl_feats,
    bev_queries,
    bev_h,
    bev_w,
    grid_length=grid_length,
    bev_pos=bev_pos,
    prev_bev=prev_bev,
    **kwargs)  # bev_embed shape: bs, bev_h*bev_w, embed_dims

bs = mlvl_feats[0].size(0)
# object_query_embed: torch.Size([900, 512])
# query_pos: torch.Size([900, 256])
# query: torch.Size([900, 256])
query_pos, query = torch.split(
    object_query_embed, self.embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
# reference_points: torch.Size([1, 900, 3])
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid()
init_reference_out = reference_points

# query: torch.Size([900, 1, 256])
query = query.permute(1, 0, 2)
# query_pos: torch.Size([900, 1, 256])
query_pos = query_pos.permute(1, 0, 2)
# bev_embed: torch.Size([50*50, 1, 256])
bev_embed = bev_embed.permute(1, 0, 2)

# Enter the decoder
inter_states, inter_references = self.decoder(
    query=query,
    key=None,
    value=bev_embed,
    query_pos=query_pos,
    reference_points=reference_points,
    reg_branches=reg_branches,
    cls_branches=cls_branches,
    spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
    level_start_index=torch.tensor([0], device=query.device),
    **kwargs)
# Return inter_states and inter_references; they are later fed to
# cls_branches and reg_branches to obtain object classes and bboxes
inter_references_out = inter_references

return bev_embed, inter_states, init_reference_out, inter_references_out
```
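To make the shape bookkeeping above concrete, here is a minimal standalone sketch that reproduces the query split and the initial reference point generation with random tensors. It assumes bs=1, 900 object queries and embed_dims=256, as in the comments above; object_query_embed and reference_points_fc are stand-ins for the module's query_embedding weight and its self.reference_points layer, not the actual BEVFormer objects.

```python
import torch
import torch.nn as nn

bs, num_query, embed_dims = 1, 900, 256

# Stand-in for self.query_embedding.weight: [num_query, 2 * embed_dims]
object_query_embed = torch.randn(num_query, 2 * embed_dims)
# Stand-in for self.reference_points: maps query_pos to (x, y, z)
reference_points_fc = nn.Linear(embed_dims, 3)

# Split along channels: the first half is query_pos, the second half is query
query_pos, query = torch.split(object_query_embed, embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)  # [1, 900, 256]
query = query.unsqueeze(0).expand(bs, -1, -1)          # [1, 900, 256]

# Initial reference points, squashed into (0, 1) by the sigmoid
with torch.no_grad():
    reference_points = reference_points_fc(query_pos).sigmoid()  # [1, 900, 3]

# The decoder works sequence-first: [num_query, bs, embed_dims]
query = query.permute(1, 0, 2)
query_pos = query_pos.permute(1, 0, 2)
```

Only query_pos feeds the reference point layer, so the content half (query) and the positional half play different roles even though they come from the same embedding table.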

2 DetectionTransformerDecoder

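Function:


  • Loops through 6 identical DetrTransformerDecoderLayers. Each DetrTransformerDecoderLayer follows the operation order ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'), and each layer outputs output and reference_points;

  • After all 6 DetrTransformerDecoderLayers have run, returns the output and reference_points of all 6 layers (stacked along a new first dimension when return_intermediate is True).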
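The layer ordering and the return_intermediate stacking described above can be sketched with plain PyTorch modules. This is a simplified stand-in, not BEVFormer's code: the cross_attn slot uses nn.MultiheadAttention over bev_embed instead of CustomMSDeformableAttention, the FFN is a hypothetical two-layer MLP, and each sub-module adds its residual before the following norm (as mmcv's attention and FFN modules do). The point is the post-norm order and the [num_layers, num_query, bs, embed_dims] output.

```python
import torch
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    """Post-norm layer following ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')."""

    def __init__(self, embed_dims=256, num_heads=8, ffn_dims=512):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dims, num_heads)
        # Stand-in for CustomMSDeformableAttention (assumption for illustration only)
        self.cross_attn = nn.MultiheadAttention(embed_dims, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dims, ffn_dims), nn.ReLU(), nn.Linear(ffn_dims, embed_dims))
        self.norms = nn.ModuleList(nn.LayerNorm(embed_dims) for _ in range(3))

    def forward(self, query, query_pos, bev_embed):
        # self_attn: key carries query_pos, value does not; add residual, then norm
        q = k = query + query_pos
        query = self.norms[0](query + self.self_attn(q, k, query)[0])
        # cross_attn over the BEV features; add residual, then norm
        attn = self.cross_attn(query + query_pos, bev_embed, bev_embed)[0]
        query = self.norms[1](query + attn)
        # ffn; add residual, then norm
        return self.norms[2](query + self.ffn(query))


def run_decoder(layers, query, query_pos, bev_embed, return_intermediate=True):
    intermediate = []
    output = query
    for layer in layers:
        output = layer(output, query_pos, bev_embed)
        if return_intermediate:
            intermediate.append(output)
    # Stack every layer's output so each one can be supervised
    return torch.stack(intermediate) if return_intermediate else output


layers = nn.ModuleList(ToyDecoderLayer() for _ in range(6)).eval()
query = torch.randn(900, 1, 256)
query_pos = torch.randn(900, 1, 256)
bev_embed = torch.randn(50 * 50, 1, 256)
with torch.no_grad():
    inter_states = run_decoder(layers, query, query_pos, bev_embed)
```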


Walkthrough:


```python
# output: torch.Size([900, 1, 256])
output = query
intermediate = []
intermediate_reference_points = []
# Loop over the 6 identical DetrTransformerDecoderLayer modules
for lid, layer in enumerate(self.layers):
    # reference_points_input: torch.Size([1, 900, 1, 2])
    # These reference points are the base sampling points that
    # CustomMSDeformableAttention uses to fuse bev_embed in the decoder
    reference_points_input = reference_points[..., :2].unsqueeze(
        2)  # BS NUM_QUERY NUM_LEVEL 2
    # Run one DetrTransformerDecoderLayer
    output = layer(
        output,
        *args,
        reference_points=reference_points_input,
        key_padding_mask=key_padding_mask,
        **kwargs)
    # output: torch.Size([1, 900, 256])
    output = output.permute(1, 0, 2)

    if reg_branches is not None:
        # tmp: torch.Size([1, 900, 10])
        tmp = reg_branches[lid](output)

        assert reference_points.shape[-1] == 3
        # new_reference_points: torch.Size([1, 900, 3])
        new_reference_points = torch.zeros_like(reference_points)
        new_reference_points[..., :2] = tmp[
            ..., :2] + inverse_sigmoid(reference_points[..., :2])
        new_reference_points[..., 2:3] = tmp[
            ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])

        new_reference_points = new_reference_points.sigmoid()

        reference_points = new_reference_points.detach()

    output = output.permute(1, 0, 2)
    if self.return_intermediate:
        intermediate.append(output)
        intermediate_reference_points.append(reference_points)

# After all 6 DetrTransformerDecoderLayers have run, return the output
# and reference_points of every layer
if self.return_intermediate:
    return torch.stack(intermediate), torch.stack(
        intermediate_reference_points)

return output, reference_points
```


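The structure of the reference_points produced by the update step above (the if reg_branches is not None block) is shown in the figure below, where inverse_sigmoid(pt_reference_points) is simply the output of the reference_points linear layer, Linear(query_pos): the initial reference_points are the sigmoid of that output, so inverse_sigmoid undoes the sigmoid.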
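The refinement works in logit space: the regression branch predicts an offset that is added to inverse_sigmoid of the current reference point, and a sigmoid maps the sum back to [0, 1]. The sketch below re-implements an eps-clamped inverse_sigmoid in the style of mmdet's helper (eps=1e-5 is an assumption here) and checks two properties: the initial reference point decodes back to the raw Linear(query_pos) output, and a zero offset leaves the reference point unchanged. update_reference is a hypothetical helper name.

```python
import torch


def inverse_sigmoid(x, eps=1e-5):
    # log(x / (1 - x)), clamped so inputs at exactly 0 or 1 stay finite
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def update_reference(reference_points, tmp):
    # reference_points: [bs, num_query, 3] in [0, 1]; tmp: [bs, num_query, 10]
    new_ref = torch.zeros_like(reference_points)
    # x, y offsets live in tmp[..., 0:2]; the z offset lives in tmp[..., 4:5]
    new_ref[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points[..., :2])
    new_ref[..., 2:3] = tmp[..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
    # detach: gradients do not flow back through the reference point chain
    return new_ref.sigmoid().detach()


logits = torch.randn(1, 900, 3)     # plays the role of Linear(query_pos)
ref0 = logits.sigmoid()             # initial reference_points
recovered = inverse_sigmoid(ref0)   # should be close to logits
unchanged = update_reference(ref0, torch.zeros(1, 900, 10))
```

Working in logit space keeps every refined reference point inside (0, 1) no matter how large the predicted offset is.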


2.1 MultiheadAttention

Function:


  • Multi-head self-attention over the object queries, as shown in the figure below.



Walkthrough:


```python
# ---------------------- MultiheadAttention setup (values used here) ----------------------
embed_dim = 256
kdim = embed_dim
vdim = embed_dim
qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim  # True
num_heads = 8
dropout = 0.1
batch_first = False
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
factory_kwargs = {'device': 'cuda', 'dtype': None}
in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
bias_k = bias_v = None
add_zero_attn = False
out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=True, **factory_kwargs)
attn_mask = attn_mask  # None

if batch_first:
    query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

if not qkv_same_embed_dim:
    # Separate q/k/v projection weights (use_separate_proj_weight=True): not taken here
    pass
else:
    attn_output, attn_output_weights = F.multi_head_attention_forward(
        query, key, value, embed_dim, num_heads,
        in_proj_weight, in_proj_bias,
        bias_k, bias_v, add_zero_attn,
        dropout, out_proj.weight, out_proj.bias,
        training=True,
        key_padding_mask=None, need_weights=True,
        attn_mask=attn_mask)

# ------------------- F.multi_head_attention_forward start -------------------
out_proj_weight = out_proj.weight
out_proj_bias = out_proj.bias
embed_dim_to_check = embed_dim
use_separate_proj_weight = False
training = True
key_padding_mask = None
need_weights = True
q_proj_weight, k_proj_weight, v_proj_weight = None, None, None
static_k, static_v = None, None

# set up shape vars
tgt_len, bsz, embed_dim = query.shape  # torch.Size([900, 1, 256])
src_len, _, _ = key.shape
assert embed_dim == embed_dim_to_check, \
    f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
# embed_dim is a plain int here (it is only a tensor under JIT tracing)
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, \
    f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"

# -------------- _in_projection_packed start --------------
# key = query + query_pos but value = query, so "key is value" is False:
# the packed shortcut is skipped and q, k, v are projected separately
w_q, w_k, w_v = in_proj_weight.chunk(3)
b_q, b_k, b_v = in_proj_bias.chunk(3)
# F.linear(x, A, b) returns x @ A.T + b
# inputs: query + query_pos (q), query + query_pos (k), query (v)
query, key, value = F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)
# -------------- _in_projection_packed end --------------

# reshape q, k, v for multi-head attention and make them batch first
# [900, 1, 256] -> [900, 8, 32] -> [8, 900, 32]
query = query.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None:
    key = key.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [8, 900, 32]
if static_v is None:
    value = value.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [8, 900, 32]

# update source sequence length after adjustments
src_len = key.size(1)

# -------------- _scaled_dot_product_attention start --------------
B, Nt, E = query.shape  # torch.Size([8, 900, 32]); key and value have the same shape
query = query / math.sqrt(E)
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns): [8, 900, 32] @ [8, 32, 900] -> [8, 900, 900]
attn = torch.bmm(query, key.transpose(-2, -1))
# attn_mask is None, so nothing is added to attn
attn = F.softmax(attn, dim=-1)
if dropout > 0.0:
    attn = F.dropout(attn, p=dropout)
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E): [8, 900, 900] @ [8, 900, 32] -> [8, 900, 32]
output = torch.bmm(attn, value)
attn_output, attn_output_weights = output, attn
# -------------- _scaled_dot_product_attention end --------------

# tgt_len: 900; merge heads: [8, 900, 32] -> [900, 8, 32] -> [900, 1, 256]
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
# output projection (nn.Linear): torch.Size([900, 1, 256])
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
# ------------------- F.multi_head_attention_forward end -------------------

out = attn_output
# ------------------------------ self.attn end ------------------------------

# mmcv MultiheadAttention: return identity + self.dropout_layer(self.proj_drop(out))
# torch.Size([900, 1, 256]) + torch.Size([900, 1, 256])
query = identity + dropout_layer(proj_drop(out))
```
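The unrolled code above follows the standard path of F.multi_head_attention_forward for the case where key and value are different tensors (key carries query_pos, value does not). As a sanity check, the sketch below repeats the same steps (chunked in-projection, head split, scaled dot product, softmax, head merge, output projection) and compares the result with torch.nn.MultiheadAttention itself. Dropout is disabled through eval mode so both results are deterministic, and the sequence length is shrunk to 10 queries only to keep it fast.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, num_heads, tgt_len, bsz = 256, 8, 10, 1
head_dim = embed_dim // num_heads

mha = torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1).eval()

query = torch.randn(tgt_len, bsz, embed_dim)
query_pos = torch.randn(tgt_len, bsz, embed_dim)
# As in the decoder self-attention: q and k carry query_pos, v does not
q_in, k_in, v_in = query + query_pos, query + query_pos, query


def manual_mha(q_in, k_in, v_in):
    # Packed in-projection weights split into q/k/v parts
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3)
    q, k, v = F.linear(q_in, w_q, b_q), F.linear(k_in, w_k, b_k), F.linear(v_in, w_v, b_v)
    # [L, bsz, 256] -> [L, bsz*8, 32] -> [bsz*8, L, 32]
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    # Scaled dot product: [8, L, 32] @ [8, 32, S] -> [8, L, S]
    attn = torch.bmm(q / math.sqrt(head_dim), k.transpose(-2, -1)).softmax(dim=-1)
    out = torch.bmm(attn, v)  # [8, L, 32]
    # Merge heads: [8, L, 32] -> [L, 8, 32] -> [L, bsz, 256]
    out = out.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    return F.linear(out, mha.out_proj.weight, mha.out_proj.bias)


with torch.no_grad():
    expected, _ = mha(q_in, k_in, v_in, need_weights=False)
    actual = manual_mha(q_in, k_in, v_in)
```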

2.2 CustomMSDeformableAttention

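Function:


  • Uses deformable attention to fuse the bev_embed output by the encoder into the object queries, as shown in the figure below;

  • Outputs this layer's output, which becomes the input of the next DetrTransformerDecoderLayer and is also used to generate this layer's reference_points.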
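Before any sampling happens, each query is turned into sampling locations: sampling_offsets predicts offsets measured in BEV grid cells, and dividing by offset_normalizer = (W, H) converts them into the same normalized [0, 1] coordinates as the reference point (one cell becomes 1/W in x or 1/H in y). The sketch below uses random weights and inputs, with a 50x50 BEV, 8 heads, 1 level and 4 points as in the shape comments of the walkthrough; the layer names mirror the original, but nothing here is loaded from BEVFormer.

```python
import torch
import torch.nn as nn

bs, num_query, embed_dims = 1, 900, 256
num_heads, num_levels, num_points = 8, 1, 4
spatial_shapes = torch.tensor([[50, 50]])  # (H, W) of the single BEV level

sampling_offsets = nn.Linear(embed_dims, num_heads * num_levels * num_points * 2)
attention_weights = nn.Linear(embed_dims, num_heads * num_levels * num_points)

query = torch.randn(bs, num_query, embed_dims)
reference_points = torch.rand(bs, num_query, num_levels, 2)  # normalized (x, y)

with torch.no_grad():
    offsets = sampling_offsets(query).view(
        bs, num_query, num_heads, num_levels, num_points, 2)
    weights = attention_weights(query).view(
        bs, num_query, num_heads, num_levels * num_points)
# Softmax over all levels * points of each head, then split levels and points again
weights = weights.softmax(-1).view(bs, num_query, num_heads, num_levels, num_points)

# spatial_shapes stores (H, W) while offsets are (x, y), so the normalizer is (W, H)
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (reference_points[:, :, None, :, None, :]
                      + offsets / offset_normalizer[None, None, None, :, None, :])
```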



Walkthrough:


```python
# ------------------- CustomMSDeformableAttention __init__ (partial) -------------------
sampling_offsets = nn.Linear(ca_embed_dims, ca_num_heads * ca_num_levels * ca_num_points * 2).cuda()
attention_weights = nn.Linear(ca_embed_dims, ca_num_heads * ca_num_levels * ca_num_points).cuda()
value_proj = nn.Linear(ca_embed_dims, ca_embed_dims).cuda()
output_proj = nn.Linear(ca_embed_dims, ca_embed_dims).cuda()
# ------------------- CustomMSDeformableAttention __init__ (partial) -------------------

if value is None:
    value = query

if identity is None:
    identity = query
if query_pos is not None:
    query = query + query_pos
if not self.batch_first:
    # change to (bs, num_query, embed_dims)
    # query: torch.Size([1, 900, 256])
    query = query.permute(1, 0, 2)
    # value (i.e. bev_embed): torch.Size([1, 50*50, 256])
    value = value.permute(1, 0, 2)

bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

# value (i.e. bev_embed): torch.Size([1, 50*50, 256])
value = self.value_proj(value)
if key_padding_mask is not None:
    value = value.masked_fill(key_padding_mask[..., None], 0.0)
# value: torch.Size([1, 50*50, 8, 32]), split into heads
value = value.view(bs, num_value, self.num_heads, -1)

# sampling_offsets: [1, 900, 8, 1, 4, 2]
sampling_offsets = self.sampling_offsets(query).view(
    bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
# attention_weights: [1, 900, 8, 4]
attention_weights = self.attention_weights(query).view(
    bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)

# attention_weights: torch.Size([1, 900, 8, 1, 4])
attention_weights = attention_weights.view(bs, num_query, self.num_heads,
                                           self.num_levels, self.num_points)
# reference_points: torch.Size([1, 900, 1, 2])
if reference_points.shape[-1] == 2:
    offset_normalizer = torch.stack(
        [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    # sampling_locations: torch.Size([1, 900, 8, 1, 4, 2])
    sampling_locations = reference_points[:, :, None, :, None, :] \
        + sampling_offsets \
        / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
    sampling_locations = reference_points[:, :, None, :, None, :2] \
        + sampling_offsets / self.num_points \
        * reference_points[:, :, None, :, None, 2:] \
        * 0.5
else:
    raise ValueError(
        f'Last dim of reference_points must be'
        f' 2 or 4, but get {reference_points.shape[-1]} instead.')

if torch.cuda.is_available() and value.is_cuda:
    # using fp16 deformable attention is unstable because it performs many sum operations
    # (note: both branches select the fp32 kernel)
    if value.dtype == torch.float16:
        MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
    else:
        MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
    output = MultiScaleDeformableAttnFunction.apply(
        value, spatial_shapes, level_start_index, sampling_locations,
        attention_weights, self.im2col_step)
else:
    # Deformable attention: the queries gather useful information from value (bev_embed)
    # output: torch.Size([1, 900, 256])
    output = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights)

# output: torch.Size([1, 900, 256])
output = self.output_proj(output)

if not self.batch_first:
    # (num_query, bs, embed_dims)
    output = output.permute(1, 0, 2)

return self.dropout(output) + identity
```
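When the CUDA kernel is not used, mmcv's multi_scale_deformable_attn_pytorch performs the actual sampling with F.grid_sample. Below is my own simplified single-level sketch of that idea, not the mmcv function itself: the value map is reshaped to [bs*heads, dim, H, W], sampling locations in [0, 1] are mapped to grid_sample's [-1, 1] range, bilinear samples are taken, and the per-point attention weights produce a weighted sum. The second call puts all weight on a point at the center of one cell, where the output should equal that cell's value.

```python
import torch
import torch.nn.functional as F


def single_level_deform_attn(value, H, W, sampling_locations, attention_weights):
    """value: [bs, H*W, heads, dim]; sampling_locations: [bs, nq, heads, points, 2]
    as (x, y) in [0, 1]; attention_weights: [bs, nq, heads, points].
    Returns [bs, nq, heads*dim]."""
    bs, _, heads, dim = value.shape
    _, nq, _, points, _ = sampling_locations.shape
    # [bs, H*W, heads, dim] -> [bs, heads, dim, H*W] -> [bs*heads, dim, H, W]
    value_map = value.permute(0, 2, 3, 1).reshape(bs * heads, dim, H, W)
    # grid_sample expects (x, y) in [-1, 1]; -> [bs*heads, nq, points, 2]
    grid = (2 * sampling_locations - 1).transpose(1, 2).flatten(0, 1)
    # Bilinear samples: [bs*heads, dim, nq, points]
    sampled = F.grid_sample(value_map, grid, mode="bilinear",
                            padding_mode="zeros", align_corners=False)
    weights = attention_weights.transpose(1, 2).reshape(bs * heads, 1, nq, points)
    out = (sampled * weights).sum(-1)  # weighted sum over points: [bs*heads, dim, nq]
    return out.view(bs, heads * dim, nq).transpose(1, 2)


bs, H, W, heads, dim, nq, points = 1, 50, 50, 8, 32, 900, 4
value = torch.randn(bs, H * W, heads, dim)
locations = torch.rand(bs, nq, heads, points, 2)
weights = torch.rand(bs, nq, heads, points).softmax(-1)
output = single_level_deform_attn(value, H, W, locations, weights)

# All weight on point 0, placed at the center of cell (row=10, col=20)
row, col = 10, 20
center = torch.tensor([(col + 0.5) / W, (row + 0.5) / H])
loc = torch.rand(1, 1, heads, points, 2)
loc[..., 0, :] = center
one_hot = torch.zeros(1, 1, heads, points)
one_hot[..., 0] = 1.0
picked = single_level_deform_attn(value, H, W, loc, one_hot)
```

With align_corners=False, normalized coordinate (col + 0.5) / W lands exactly on the center of column col, which is why the one-hot case reproduces the stored value instead of an interpolated blend.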

3 cls_branches & reg_branches

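Function:


  • Uses the transformer outputs bev_embed, inter_states (the outputs of the 6 decoder layers), init_reference_out (the initial reference_points generated from query_pos) and inter_references_out (the reference_points of the 6 decoder layers) to generate object classes and bboxes;

  • Produces outs, which contains bev_embed, all_cls_scores and all_bbox_preds. all_cls_scores and all_bbox_preds are used to compute the loss and backpropagate gradients; bev_embed can be reused for tasks such as segmentation, i.e. semantic segmentation in the BEV view.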
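With box refinement enabled, each decoder layer gets its own classification and regression branch, and the head applies branch lvl to layer lvl's output before stacking the results. The sketch below builds such branches in the style of mmdet's DETR-family heads (hidden FCs with LayerNorm in the classification branch only); num_reg_fcs=2, 10 classes and code_size=10 are assumptions chosen to match the shapes in the comments. It omits the reference point decoding.

```python
import copy

import torch
import torch.nn as nn

embed_dims, num_classes, code_size, num_layers, num_reg_fcs = 256, 10, 10, 6, 2


def make_branch(out_dims, use_layernorm):
    layers = []
    for _ in range(num_reg_fcs):
        layers.append(nn.Linear(embed_dims, embed_dims))
        if use_layernorm:
            layers.append(nn.LayerNorm(embed_dims))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(embed_dims, out_dims))
    return nn.Sequential(*layers)


# One independent copy per decoder layer (parameters are not shared)
cls_branch = make_branch(num_classes, use_layernorm=True)
reg_branch = make_branch(code_size, use_layernorm=False)
cls_branches = nn.ModuleList(copy.deepcopy(cls_branch) for _ in range(num_layers))
reg_branches = nn.ModuleList(copy.deepcopy(reg_branch) for _ in range(num_layers))

inter_states = torch.randn(num_layers, 900, 1, embed_dims)  # decoder output
hs = inter_states.permute(0, 2, 1, 3)  # [6, 1, 900, 256]
with torch.no_grad():
    all_cls_scores = torch.stack([cls_branches[lvl](hs[lvl]) for lvl in range(num_layers)])
    all_bbox_preds = torch.stack([reg_branches[lvl](hs[lvl]) for lvl in range(num_layers)])
```

Predicting from every layer is what allows the loss to supervise all 6 decoder layers rather than only the last one.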


Walkthrough:


```python
# See "BEVFormer Open-Source Algorithm Line-by-Line Analysis (Part 1): Encoder" for these variables
bs, num_cam, _, _, _ = mlvl_feats[0].shape
dtype = mlvl_feats[0].dtype
object_query_embeds = self.query_embedding.weight.to(dtype)
bev_queries = self.bev_embedding.weight.to(dtype)
bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                       device=bev_queries.device).to(dtype)
bev_pos = self.positional_encoding(bev_mask).to(dtype)

if only_bev:  # only use encoder to obtain BEV features, TODO: refine the workaround
    return self.transformer.get_bev_features(
        mlvl_feats,
        bev_queries,
        self.bev_h,
        self.bev_w,
        grid_length=(self.real_h / self.bev_h,
                     self.real_w / self.bev_w),
        bev_pos=bev_pos,
        img_metas=img_metas,
        prev_bev=prev_bev,
    )
else:
    # outputs is the final result of passing object_query_embeds, bev_pos, bev_queries,
    # img_metas and mlvl_feats through the encoder and decoder:
    # bev_embed, inter_states, init_reference_out, inter_references_out
    outputs = self.transformer(
        mlvl_feats,
        bev_queries,
        object_query_embeds,
        self.bev_h,
        self.bev_w,
        grid_length=(self.real_h / self.bev_h,
                     self.real_w / self.bev_w),
        bev_pos=bev_pos,
        reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
        cls_branches=self.cls_branches if self.as_two_stage else None,
        img_metas=img_metas,
        prev_bev=prev_bev)

bev_embed, hs, init_reference, inter_references = outputs
# hs: [6, 900, 1, 256] -> [6, 1, 900, 256]
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
    if lvl == 0:
        reference = init_reference
    else:
        reference = inter_references[lvl - 1]
    reference = inverse_sigmoid(reference)
    # outputs_class: torch.Size([1, 900, 10])
    outputs_class = self.cls_branches[lvl](hs[lvl])
    # tmp: torch.Size([1, 900, 10])
    tmp = self.reg_branches[lvl](hs[lvl])

    # TODO: check the shape of reference
    assert reference.shape[-1] == 3
    tmp[..., 0:2] += reference[..., 0:2]
    tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
    tmp[..., 4:5] += reference[..., 2:3]
    tmp[..., 4:5] = tmp[..., 4:5].sigmoid()
    # The "* (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]" terms below
    # map the normalized x, y, z of the box centers back to real-world coordinates
    tmp[..., 0:1] = (tmp[..., 0:1] * (self.pc_range[3] -
                     self.pc_range[0]) + self.pc_range[0])
    tmp[..., 1:2] = (tmp[..., 1:2] * (self.pc_range[4] -
                     self.pc_range[1]) + self.pc_range[1])
    tmp[..., 4:5] = (tmp[..., 4:5] * (self.pc_range[5] -
                     self.pc_range[2]) + self.pc_range[2])

    # TODO: check if using sigmoid
    outputs_coord = tmp
    outputs_classes.append(outputs_class)
    outputs_coords.append(outputs_coord)

# outputs_classes: torch.Size([6, 1, 900, 10])
outputs_classes = torch.stack(outputs_classes)
# outputs_coords: torch.Size([6, 1, 900, 10])
outputs_coords = torch.stack(outputs_coords)

outs = {
    'bev_embed': bev_embed,
    'all_cls_scores': outputs_classes,
    'all_bbox_preds': outputs_coords,
    'enc_cls_scores': None,
    'enc_bbox_preds': None,
}

# After outs is returned, it is combined with the class labels and bbox labels
# to compute the loss; gradients are then backpropagated to update the learnable
# parameters: the linear layers, object_query_embeds, bev_queries, bev_pos, etc.
return outs
```


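The structure of tmp[..., 0:2] and tmp[..., 4:5] after the reference is added and the sigmoid is applied (before rescaling by pc_range) is shown in the figure below; in essence they are the same reference_points generated in DetectionTransformerDecoder.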
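Putting the last step into numbers: the head adds the logit of the reference point to tmp[..., 0:2] (x, y) and tmp[..., 4:5] (z), applies a sigmoid, and rescales with pc_range. Index 4 holds z because, as far as I can tell from BEVFormer's normalize_bbox, the 10-dim box code is laid out as (cx, cy, w, l, cz, h, sin(rot), cos(rot), vx, vy). The sketch assumes pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0], the point cloud range used in the BEVFormer nuScenes configs, and decode_centers is a hypothetical helper name.

```python
import torch

# Assumed (x_min, y_min, z_min, x_max, y_max, z_max) in meters
pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]


def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))


def decode_centers(tmp, reference):
    """tmp: [bs, nq, 10] raw reg-branch output; reference: [bs, nq, 3] in [0, 1]."""
    tmp = tmp.clone()
    ref = inverse_sigmoid(reference)
    tmp[..., 0:2] = (tmp[..., 0:2] + ref[..., 0:2]).sigmoid()  # normalized cx, cy
    tmp[..., 4:5] = (tmp[..., 4:5] + ref[..., 2:3]).sigmoid()  # normalized cz
    # Map [0, 1] back to meters inside the point cloud range
    tmp[..., 0:1] = tmp[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    tmp[..., 1:2] = tmp[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    tmp[..., 4:5] = tmp[..., 4:5] * (pc_range[5] - pc_range[2]) + pc_range[2]
    return tmp


# A reference point in the middle of the range with a zero offset decodes to (0, 0, -1)
center = decode_centers(torch.zeros(1, 1, 10), torch.full((1, 1, 3), 0.5))
```

The other seven entries (size, rotation, velocity) pass through untouched here; in the original code they are only converted back to boxes later, outside this forward pass.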


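Conclusion:

This completes the line-by-line code walkthrough of BEVFormer's Encoder and Decoder. If there is demand, a follow-up on the loss computation may come later; that part is fairly basic, so interested readers can also study it on their own alongside the source code.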


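BEVFormer Open-Source Algorithm Line-by-Line Analysis (Part 2): Decoder and Det | Autonomous Driving | 地平线智能驾驶开发者 (Horizon Robotics Intelligent Driving Developer) | InfoQ Writing Community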