import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import DeQuantStub

from horizon_plugin_pytorch import set_march, March

set_march(March.NASH_M)

from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, default_calibration_qconfig_setter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8, qint16
from hbdk4.compiler import statistics, save, load, visualize, compile, convert, hbm_perf
class SimpleConvNet(nn.Module):
    """Small conv -> ReLU -> max-pool -> linear network wrapped in quant stubs.

    ``QuantStub`` / ``DeQuantStub`` mark where the network enters and leaves
    the quantized domain for ``horizon_plugin_pytorch``'s ``prepare``.

    Args:
        in_channels: Number of input image channels (default 1, grayscale).
        num_classes: Number of output logits (default 10).
        input_size: ``(height, width)`` of the input image (default 28x28).
            Used to size the fully-connected layer: the 3x3 / padding-1 /
            stride-1 conv keeps the spatial size and the 2x2 / stride-2
            max-pool halves it (rounding down).

    Raises:
        ValueError: If ``input_size`` is smaller than 2 in either dimension,
            which would leave nothing after pooling.
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=(28, 28)):
        super().__init__()
        height, width = input_size
        if height < 2 or width < 2:
            raise ValueError(
                f"input_size must be at least (2, 2), got {input_size!r}"
            )
        # First node: in_channels -> 16 channels, 3x3 kernel; padding=1 keeps H and W.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        # 2x2 max-pool with stride 2 halves H and W (floor division).
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Derived from input_size instead of the previous hard-coded 16 * 14 * 14.
        self.fc = nn.Linear(16 * (height // 2) * (width // 2), num_classes)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Image batch laid out as ``(N, in_channels, H, W)`` with
                ``(H, W)`` equal to ``input_size``.

        Returns:
            Tensor of shape ``(N, num_classes)``, passed through ``DeQuantStub``.
        """
        x = self.quant(x)
        x = self.conv1(x)  # convolution
        x = F.relu(x)  # activation
        x = self.pool(x)  # pooling
        x = x.view(x.size(0), -1)  # flatten to (N, 16 * H//2 * W//2)
        x = self.fc(x)  # fully-connected output layer
        x = self.dequant(x)
        return x
# --- Float model + fake input ------------------------------------------------
# The float network is created first; its weight initialization consumes the
# global RNG before the fake input below is drawn.
model = SimpleConvNet()

# Fake input: a batch of 4 single-channel 28x28 images, laid out NCHW.
example_input = torch.randn((4, 1, 28, 28))

# Float smoke test: one forward pass, then report the logits' shape.
output = model(example_input)
print("输出 shape:", output.size())  # expected: torch.Size([4, 10])
# --- Calibration --------------------------------------------------------------
# prepare() takes the eval-mode float model plus an example input and applies
# the default calibration qconfig setter, returning the model to calibrate.
calib_model = prepare(
    model.eval(),
    example_input,
    qconfig_setter=(default_calibration_qconfig_setter,),
)

# Calibration pass. Gradients are never needed here, so run under no_grad to
# avoid building an autograd graph.
# NOTE(review): a single batch of torch.randn noise is not representative of
# real images; feed real samples normalized the same way as the on-device
# preprocessing inserted after export.
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
with torch.no_grad():
    calib_model(example_input)

# Validation pass with the calibrated fake-quant model.
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
with torch.no_grad():
    calib_out = calib_model(example_input)
print("calib输出数据:", calib_out)
# --- Export + on-device input preprocessing -----------------------------------
# Export the validated fake-quant model to an hbdk4 module.
qat_bc = export(calib_model, example_input)

# Normalization parameters for the single gray channel.
# NOTE(review): 0.485 / 0.229 are the ImageNet red-channel statistics; confirm
# they match how the model's training data was normalized.
mean = [0.485]
std = [0.229]
func = qat_bc[0]

# For every model input, split the batch dimension into separate per-sample
# inputs and prepend the transpose / image-preprocess / gray-convert ops.
# Inputs are visited in reverse order, as in the original flow (presumably so
# that inserting ops does not disturb inputs not yet processed — verify).
for model_input in func.flatten_inputs[::-1]:
    split_inputs = model_input.insert_split(dim=0)
    for split_input in reversed(split_inputs):
        # Permutation [0, 3, 1, 2]; presumably the new input is NHWC and is
        # transposed back to the model's NCHW layout — confirm in hbdk4 docs.
        node = split_input.insert_transpose([0, 3, 1, 2])
        # presumably computes (x / divisor - mean) / std on device — verify.
        node = node.insert_image_preprocess(
            mode="skip", divisor=255, mean=mean, std=std, is_signed=True
        )
        node.insert_image_convert(mode="gray")
# --- Convert, inspect and compile --------------------------------------------
MARCH = "nash-m"
OUTPUT_DIR = "model_result"
# visualize() and compile() write into OUTPUT_DIR; it must exist beforehand.
os.makedirs(OUTPUT_DIR, exist_ok=True)

quantized_bc = convert(qat_bc, MARCH)
hbir_func = quantized_bc.functions[0]
# Remove the boundary Quantize/Dequantize ops from the converted function's I/O.
hbir_func.remove_io_op(op_types=["Dequantize", "Quantize"])
visualize(quantized_bc, f"{OUTPUT_DIR}/quantized_batch4.onnx")
statistics(quantized_bc)

# Options forwarded verbatim to hbdk4 compile().
params = {
    'jobs': 64,
    'balance': 100,
    'progress_bar': True,
    'opt': 2,
    'debug': True,
    "advice": 0.0,
}
hbm_path = f"{OUTPUT_DIR}/batch4-gray.hbm"
print("start to compile")
compile(quantized_bc, march=MARCH, path=hbm_path, **params)
print("end to compile")
# Comments ("评论" — stray label from the scraped source page; not part of the script)