掘金 人工智能 08月20日
征程 6 | PTQ 精度调优辅助代码,总有你用得上的
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文档提供了一系列ONNX模型处理和精度调优的实用技巧。首先介绍了如何使用`onnx.utils.extract_model`截取模型片段,便于调试和分析。接着,展示了如何通过Python API快速获取模型中所有算子(op type)的列表,以辅助精调。随后,详细阐述了如何手动编写代码,利用余弦相似度对比浮点模型与量化模型的推理结果,以评估精度损失。最后,提供了将精度调试日志完整保存到文件的方法,解决了终端输出截断的问题,便于深入分析和定位量化风险。这些方法旨在帮助开发者更高效地处理和优化ONNX模型。

📝 **模型片段截取**:利用`onnx.utils.extract_model`函数,可以从完整的ONNX模型中提取出特定的子模型或子图。这一功能对于隔离问题、加速推理或制作简化模型进行校准(如量化前)至关重要,通过指定输入和输出节点,ONNX库会自动构建一个逻辑连贯的新模型文件,但需注意某些节点可能不支持此操作。

💡 **算子类型识别**:为了进行精细的精度调优(如将特定算子设置为INT16或FP32),需要快速了解模型中包含的所有算子类型。通过加载优化后的ONNX模型并遍历其节点,可以收集所有`op_type`,这对于后续的YAML配置至关重要,但需注意可能存在遗漏,建议结合编译日志进行补充。

📊 **量化精度对比**:为了更全面地评估量化对模型精度的影响,可以手动编写Python脚本来计算量化前后模型输出的相似度。文中提供的代码示例展示了如何使用`HBRuntime`加载浮点模型和量化模型,输入相同的校准数据,然后计算它们输出结果之间的余弦相似度,从而量化精度损失情况。

🗄️ **调试日志保存**:在进行精度调试时,尤其当模型庞大、算子众多时,直接在终端执行命令可能导致输出信息被截断。推荐使用`tee`或`nohup`结合重定向的方式,将精度调试的详细日志保存到本地文件,这样可以确保所有信息都被完整记录,便于后续的分析和问题定位,提高调试效率。

一、截取 onnx 模型片段

在模型编译的时候,往往会出现各种各样的报错,您可能会受限于公司要求,无法把完整的 onnx 模型发送给地平线做分析,此时可以考虑截取 onnx 模型,找到可复现的片段,再将该片段提供给地平线的技术支持人员(这种方式通常是可以被公司允许的)。我们可以直接利用 onnx 的 python api 去完成这件事,onnx.utils.extract_model 的具体使用方式如下:

onnx.utils.extract_modelonnx 官方库提供的一个 从 ONNX 模型中提取子模型(子图) 的实用函数。 它的常见用途包括:

提取模型中某一部分用于调试或加速推理;

截取特定中间层的输出;

制作用于 calibration 的简化子模型(例如量化前的子图)。

1.1 基本用法

from onnx import utils utils.extract_model( input_path, # 原始模型文件路径(.onnx)output_path, # 提取后的子模型保存路径(.onnx)input_names, # 子模型的输入节点名列表(字符串output_names # 子模型的输出节点名列表(字符串))

ONNX 会自动计算以这些输入到输出为范围的所有依赖节点,然后构造出一个合法的新 ONNX 子模型。

举例来说,我们可以运行这样的代码截取 onnx 模型片段:

import onnxonnx.utils.extract_model('./heatmap_based_grasp_v2_op16_post_3.onnx','grasp_backbone.onnx',['input_name'],['output_name'])

请注意,这个脚本不一定能一次直接运行成功,有些节点不支持 extract,如果失败,可以尝试配置模型节点的 name 或者节点输入输出的 name,或者尝试使用其他节点或输入输出。

二、打印模型的所有 op type

当我们在做比较细致的精度调优时,会想把某一类算子全部配置 int16 或者 float32,而如何快速知道模型有哪些算子类别呢?首先我们要知道,PTQ 在校准的时候,回忆 optimized onnx 为处理对象,因此我们只要知道 optimized onnx 模型有哪些算子类别就可以了。用户使用 PTQ 时,可以使用我们提供好的功 python api 完成这个事情,举例如下:

from hmct.ir import load_model model = load_model('optimized.onnx')op_types = set()for node in model.graph.nodes:    op_types.add(node.op_type)print(op_types)

yaml 配置算子时,以 optimized 模型的 node name 为准,这里打印 op type 最好也选 optimized.onnx,但该功能可能会有部分 op type 漏打印的现象,可以检查 hb_compile 之后的 log 查漏补缺。

三、手动计算量化前后相似度

虽然模型编译的时候,日志里会提供相似度的情况,但显示的并不够完整。如果我们想知道所有输入数据或者某个特定输入数据的相似度情况,或者不同阶段模型的相似度情况,就需要手写代码去做模型推理。这里我提供这样一份参考代码,用来读取文件夹里所有数据,并挨个打印相似度,满足我们的相似度计算需求,从而更好地了解模型的精度情况。

import numpy as npfrom horizon_tc_ui.hb_runtime import HBRuntimedef cosine_similarity(arr1, arr2):    arr1 = np.array(arr1)    arr2 = np.array(arr2)    dot_product = np.dot(arr1, arr2)  # 计算点积    norm_arr1 = np.linalg.norm(arr1)  # 计算arr1的范数    norm_arr2 = np.linalg.norm(arr2)  # 计算arr2的范数    if norm_arr1 == 0 or norm_arr2 == 0:        return 0.0    similarity = dot_product / (norm_arr1 * norm_arr2)  # 计算余弦相似度    return similaritydef float_vs_quant(i):    try:        input_depth = np.fromfile(f"./calib_data_bin/input_depth/{i}_input_depth.bin", dtype=np.float32).reshape(1,1280,720)    except Exception as e:        return 0    xyzrgb_tensor = np.fromfile(f"./calib_data_bin/xyzrgb_tensor/{i}_xyzrgb_tensor.bin", dtype=np.float32).reshape(1,6,640,360)    sess = HBRuntime("./model_output/grasp_original_float_model.onnx")    input_names = sess.input_names    output_names = sess.output_names    input_feed = {input_names[0]: input_depth,                   input_names[1]: xyzrgb_tensor,                  }    output = sess.run(output_names, input_feed)    pred_6d_grasp_float = output[0]    print("float")    print(pred_6d_grasp_float[0])    sess = HBRuntime("./model_output/grasp_quantized_model.bc")    input_names = sess.input_names    output_names = sess.output_names    input_feed = {input_names[0]: input_depth,                   input_names[1]: xyzrgb_tensor,                  }    output = sess.run(output_names, input_feed)    pred_6d_grasp_quant = output[0]    print("quant")    print(pred_6d_grasp_quant[0])    return cosine_similarity(pred_6d_grasp_float.reshape(-1), pred_6d_grasp_quant.reshape(-1))if __name__ == "__main__":    for i in range(0, 100):        cos_sim = float_vs_quant(i)        if cos_sim == 0:            continue        print(str(i)+ " "+str(cos_sim)+"\n")

3.1 主函数:float_vs_quant(i)

这个函数是整个对比流程的核心:

读取输入数据
input_depth = np.fromfile(..., dtype=np.float32).reshape(1,1280,720) xyzrgb_tensor = np.fromfile(..., dtype=np.float32).reshape(1,6,640,360)

分别读取第 i 条样本的:

深度图:维度是 (1, 1280, 720)

RGB + XYZ 张量:维度是 (1, 6, 640, 360)

若文件读取失败(如不存在),直接返回 0(跳过该样本)。

推理浮点模型
sess = HBRuntime("./model_output/grasp_original_float_model.onnx") ... pred_6d_grasp_float = output[0]

使用 HBRuntime 加载浮点版模型 .onnx

输入数据为 input_depthxyzrgb_tensor

获取输出(预测的 6D 抓取向量)。

推理量化模型
sess = HBRuntime("./model_output/grasp_quantized_model.bc") ... pred_6d_grasp_quant = output[0]

加载量化后的模型(通常是 .bc 格式,地平线编译后模型)。

同样输入相同的数据。

获取量化模型的推理结果。

余弦相似度对比
return cosine_similarity(pred_6d_grasp_float.reshape(-1), pred_6d_grasp_quant.reshape(-1))

将浮点和量化预测结果展平成一维向量。

计算并返回它们的余弦相似度。

3.2 主执行入口

if name == "__main__":for i in range(0, 100): cos_sim = float_vs_quant(i) ...

对编号为 0 到 99 的输入样本,逐个调用 float_vs_quant

跳过无效样本(返回值为 0)。

打印样本编号及对应的相似度。

3.3 输出示例

float [0.12 0.45 0.87 ...] quant [0.11 0.44 0.85 ...] 3 0.99875 ...

展示浮点和量化模型的推理输出;

输出每个有效样本的相似度分数(越接近 1 表示差异越小)。

四、精度 debug 保存终端打印日志

通常来说,我们可以使用精度 DEBUG 功能,去查看哪些算子的量化风险高,从而为其设置更高的量化精度。但有时,我们的模型特别大,算子特别多,如果直接在 vscode 终端执行精度 debug 命令,打印的算子信息很可能不够完整,会被截断,因此可以使用下面介绍方法将精度 debug 日志完整地保存到本地文件里。

首先,我们先在 debug.py 脚本里写好我们要执行的命令,比如:

import loggingimport hmct.quantizer.debugger as dbg# 若verbose=True时,需要先设置log level为INFOlogging.getLogger().setLevel(logging.INFO)# 获取节点量化敏感度node_message = dbg.get_sensitivity_of_nodes(        model_or_file='../model_output_debug/graspnet_calibrated_model.onnx',        metrics=['cosine-similarity', 'mse'],        calibrated_data='../calibration_data/',        output_node=None,        node_type='activation',        data_num=None,        verbose=True,        interested_nodes=None)

方法 1、前台运行时保存,这个方法会持续占用该终端,直到程序结束。

python3 debug.py 2>&1 | tee debug_output.txt

方法 2、后台运行时保存,这是更为推荐的方法,这样我们在这个终端还可以同时做其他事,比如同时运行其他 node_type 的静的 debug 功能。

nohup python3 debug.py >debug_output.log 2>&1 &

在程序运行时,目录下就会生成对应的 log 日志,我们可以随时查看精度 debug 的进展,非常方便!

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

ONNX 模型优化 精度调优 量化 AI推理
相关文章