from typing import Sequence, Tuple
import horizon_plugin_pytorch.nn as hnn
import torch
import torch.nn as nn
from horizon_plugin_pytorch.quantization import QuantStub
from torch.quantization import DeQuantStub
# 基础模块的代码,可见地平线提供的OE docker
# /usr/local/lib/python3.10/dist-packages/hat/models/base_modules/basic_henet_module.py
from basic_henet_module import (
BasicHENetStageBlock, # HENet 的基本阶段块
S2DDown, # 降采样(downsampling)模块
)
from basic_henet_module import ConvModule2d # 2D 卷积层模块
# 继承 torch.nn.Module,定义神经网络的标准方式
class HENet(nn.Module):
    """
    Module of HENet.

    Structure: a ``patch_embed`` stem (two 3x3 stride-2 convs, i.e. 4x
    spatial downsampling), then ``len(block_nums)`` stages of
    ``BasicHENetStageBlock`` with a downsample block between consecutive
    stages. With ``include_top`` the network ends in pooling + a linear
    classifier; otherwise ``forward`` returns a list of per-stage features.

    Args:
        in_channels: The in_channels for the block.
        block_nums: Number of blocks in each stage.
        embed_dims: Output channels in each stage.
        attention_block_num: Number of attention blocks in each stage.
        mlp_ratios: Mlp expand ratios in each stage.
        mlp_ratio_attn: Mlp expand ratio in attention blocks.
        act_layer: activation layers type, given as strings such as
            "nn.GELU". The last entry is ``eval``-ed in this module's
            namespace to build the head-side activations, so the config
            must come from a trusted source.
        use_layer_scale: Use a learnable scale factor in the residual branch.
        layer_scale_init_value: Init value of the learnable scale factor.
        num_classes: Number of classes for a Classifier. If <= 0, the head
            is an identity and the pooled (flattened) features are returned.
        include_top: Whether to include output layer.
        flat_output: Whether to view the output tensor as
            ``(-1, num_classes)``. Only applied when ``num_classes > 0``.
        extra_act: Use extra activation layers in each stage.
        final_expand_channel: Channel expansion before pooling. 0/None
            disables it, in which case a BatchNorm is applied instead.
        feature_mix_channel: Channel expansion performed after pooling,
            before the head. 0/None disables it.
        block_cls: Basic block types in each stage.
        down_cls: Downsample block types between stages. Only "S2DDown" is
            supported; the last entry is unused (no downsample after the
            final stage).
        patch_embed: Stem conv style in the very beginning. Only "origin"
            is supported.
        stage_out_norm: Add a norm layer to stage outputs.
            Ignored if include_top is True.

    Raises:
        ValueError: If ``patch_embed`` or a used ``down_cls`` entry is not
            supported.
    """

    def __init__(
        self,
        in_channels: int,  # input image channels (3 for RGB images)
        block_nums: Tuple[int, ...],  # number of blocks in each stage
        embed_dims: Tuple[int, ...],  # feature channels of each stage
        attention_block_num: Tuple[int, ...],  # attention blocks per stage
        mlp_ratios: Tuple[int, ...] = (2, 2, 2, 2),  # MLP expand ratios
        mlp_ratio_attn: int = 2,
        act_layer: Tuple[str, ...] = ("nn.GELU", "nn.GELU", "nn.GELU", "nn.GELU"),
        use_layer_scale: Tuple[bool, ...] = (True, True, True, True),
        layer_scale_init_value: float = 1e-5,
        num_classes: int = 1000,
        include_top: bool = True,  # include pooling + nn.Linear classifier
        flat_output: bool = True,
        extra_act: Tuple[bool, ...] = (False, False, False, False),
        final_expand_channel: int = 0,
        feature_mix_channel: int = 0,
        block_cls: Tuple[str, ...] = ("DWCB", "DWCB", "DWCB", "DWCB"),
        down_cls: Tuple[str, ...] = ("S2DDown", "S2DDown", "S2DDown", "None"),
        patch_embed: str = "origin",  # stem (conv embedding) style
        stage_out_norm: bool = True,  # BatchNorm on stage outputs (backbone mode)
    ):
        super().__init__()
        self.final_expand_channel = final_expand_channel
        self.feature_mix_channel = feature_mix_channel
        self.stage_out_norm = stage_out_norm
        self.block_cls = block_cls
        self.include_top = include_top
        self.flat_output = flat_output
        if self.include_top:
            self.num_classes = num_classes

        # Stem: two 3x3 stride-2 ConvModule2d layers, i.e. 4x spatial
        # downsampling. Fail fast on unknown styles instead of leaving
        # self.patch_embed undefined until forward().
        if patch_embed != "origin":
            raise ValueError(
                f"Unsupported patch_embed {patch_embed!r}; expected 'origin'."
            )
        self.patch_embed = nn.Sequential(
            ConvModule2d(
                in_channels,
                embed_dims[0] // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_layer=nn.BatchNorm2d(embed_dims[0] // 2),
                act_layer=nn.ReLU(),
            ),
            ConvModule2d(
                embed_dims[0] // 2,
                embed_dims[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_layer=nn.BatchNorm2d(embed_dims[0]),
                act_layer=nn.ReLU(),
            ),
        )

        # Supported inter-stage downsample blocks. An explicit table replaces
        # eval() of the config string plus an assert (which -O strips).
        downsample_classes = {"S2DDown": S2DDown}

        # stages[i] processes features with embed_dims[i] channels;
        # downsample_block[i] maps stage i's output to stage i + 1's input.
        stages = []
        downsample_block = []
        last_stage_idx = len(block_nums) - 1
        for block_idx, block_num in enumerate(block_nums):
            stages.append(
                BasicHENetStageBlock(
                    in_dim=embed_dims[block_idx],
                    block_num=block_num,
                    attention_block_num=attention_block_num[block_idx],
                    mlp_ratio=mlp_ratios[block_idx],
                    mlp_ratio_attn=mlp_ratio_attn,
                    act_layer=act_layer[block_idx],
                    use_layer_scale=use_layer_scale[block_idx],
                    layer_scale_init_value=layer_scale_init_value,
                    extra_act=extra_act[block_idx],
                    block_cls=block_cls[block_idx],
                )
            )
            if block_idx < last_stage_idx:
                down_name = down_cls[block_idx]
                if down_name not in downsample_classes:
                    raise ValueError(
                        f"Unsupported down_cls[{block_idx}] {down_name!r}; "
                        f"expected one of {sorted(downsample_classes)}."
                    )
                downsample_block.append(
                    downsample_classes[down_name](
                        patch_size=2,
                        in_dim=embed_dims[block_idx],
                        out_dim=embed_dims[block_idx + 1],
                    )
                )
        self.stages = nn.ModuleList(stages)
        self.downsample_block = nn.ModuleList(downsample_block)

        # Optional 1x1 channel expansion before pooling. When disabled, the
        # last stage output is batch-normalized instead (see forward()).
        if final_expand_channel in [0, None]:
            self.final_expand_layer = nn.Identity()
            self.norm = nn.BatchNorm2d(embed_dims[-1])
            last_channels = embed_dims[-1]
        else:
            self.final_expand_layer = ConvModule2d(
                embed_dims[-1],
                final_expand_channel,
                kernel_size=1,
                bias=False,
                norm_layer=nn.BatchNorm2d(final_expand_channel),
                # NOTE: eval of a config string such as "nn.GELU"; trusted input only.
                act_layer=eval(act_layer[-1])(),
            )
            last_channels = final_expand_channel

        # Optional 1x1 channel mixing, applied after pooling (see forward()).
        if feature_mix_channel in [0, None]:
            self.feature_mix_layer = nn.Identity()
        else:
            self.feature_mix_layer = ConvModule2d(
                last_channels,
                feature_mix_channel,
                kernel_size=1,
                bias=False,
                norm_layer=nn.BatchNorm2d(feature_mix_channel),
                # NOTE: eval of a config string such as "nn.GELU"; trusted input only.
                act_layer=eval(act_layer[-1])(),
            )
            last_channels = feature_mix_channel

        if self.include_top:
            # Classifier head: global average pool to 1x1, then Linear.
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.head = (
                nn.Linear(last_channels, num_classes)
                if num_classes > 0
                else nn.Identity()
            )
        else:
            # Backbone mode: optional per-stage BatchNorm on the outputs.
            stage_norm = []
            for embed_dim in embed_dims:
                if self.stage_out_norm is True:
                    stage_norm.append(nn.BatchNorm2d(embed_dim))
                else:
                    stage_norm.append(nn.Identity())
            self.stage_norm = nn.ModuleList(stage_norm)
            # 2x bilinear upsampling, applied to the first stage's output.
            self.up = hnn.Interpolate(
                scale_factor=2, mode="bilinear", recompute_scale_factor=True
            )
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor, or a single-element sequence wrapping it.

        Returns:
            With ``include_top``: the (dequantized) head output, viewed as
            ``(-1, num_classes)`` when ``flat_output`` and ``num_classes > 0``.
            Otherwise: a list with the 2x-upsampled first-stage output
            followed by every stage's (optionally normalized) output.
        """
        # Unwrap a 1-element sequence before quantizing so the QuantStub
        # always receives a tensor.
        if isinstance(x, Sequence) and len(x) == 1:
            x = x[0]
        x = self.quant(x)

        # Stem, then the stages with a downsample between consecutive stages.
        x = self.patch_embed(x)
        outs = []
        last_idx = len(self.stages) - 1
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if not self.include_top:
                x_normed = self.stage_norm[idx](x)
                if idx == 0:
                    outs.append(self.up(x_normed))
                outs.append(x_normed)
            if idx < last_idx:
                x = self.downsample_block[idx](x)
        if not self.include_top:
            return outs

        if self.final_expand_channel in [0, None]:
            x = self.norm(x)
        else:
            x = self.final_expand_layer(x)
        x = self.avgpool(x)
        x = self.feature_mix_layer(x)
        x = self.head(torch.flatten(x, 1))
        x = self.dequant(x)
        # With num_classes <= 0 the head is an identity and view(-1, 0) would
        # raise; the tensor is already 2-D from torch.flatten in that case.
        if self.flat_output and self.num_classes > 0:
            x = x.view(-1, self.num_classes)
        return x
# ---------------------- TinyM ----------------------
# HENet-TinyM configuration (one entry per stage).
depth = [4, 3, 8, 6]
block_cls = ["GroupDWCB", "GroupDWCB", "AltDWCB", "DWCB"]
width = [64, 128, 192, 384]
attention_block_num = [0, 0, 0, 0]
mlp_ratios, mlp_ratio_attn = [2, 2, 2, 3], 2
act_layer = ["nn.GELU", "nn.GELU", "nn.GELU", "nn.GELU"]
use_layer_scale = [True, True, True, True]
extra_act = [False, False, False, False]
final_expand_channel, feature_mix_channel = 0, 1024
down_cls = ["S2DDown", "S2DDown", "S2DDown", "None"]
patch_embed = "origin"
stage_out_norm = False

# Build the HENet model.
model = HENet(
    in_channels=3,  # RGB input
    block_nums=tuple(depth),
    embed_dims=tuple(width),
    attention_block_num=tuple(attention_block_num),
    mlp_ratios=tuple(mlp_ratios),
    mlp_ratio_attn=mlp_ratio_attn,
    act_layer=tuple(act_layer),
    use_layer_scale=tuple(use_layer_scale),
    extra_act=tuple(extra_act),
    final_expand_channel=final_expand_channel,
    feature_mix_channel=feature_mix_channel,
    block_cls=tuple(block_cls),
    down_cls=tuple(down_cls),
    patch_embed=patch_embed,
    stage_out_norm=stage_out_norm,
    num_classes=1000,  # e.g. ImageNet-1k classification
    include_top=True,
)

# ---------------------- Single-frame input ----------------------
# Random stand-in for one 224x224 RGB image: [batch, channels, height, width].
input_tensor = torch.randn(1, 3, 224, 224)

# ---------------------- Inference ----------------------
model.eval()
with torch.no_grad():  # no autograd bookkeeping needed for inference
    output = model(input_tensor)

# ---------------------- Results ----------------------
print("模型输出形状:", output.shape)
print("模型输出类型:", type(output))
print("模型输出长度:", len(output))
print(output)
# Index of the largest logit (no softmax applied); .item() assumes batch size 1.
print("预测类别索引:", torch.argmax(output, dim=1).item())

# ---------------------- Complexity ----------------------
# thop.profile returns multiply-accumulate operations (MACs), not FLOPs.
# One MAC is conventionally counted as ~2 FLOPs.
from thop import profile

macs, params = profile(model, inputs=(input_tensor,))
flops = 2 * macs
print(f"MACs: {macs / 1e6:.2f}M")
print(f"FLOPs (~2 x MACs): {flops / 1e6:.2f}M")
print(f"Params: {params / 1e6:.2f}M")