ONNX 转 TensorRT Bug 记录:IIfConditionalOutputLayer

embedded/2024/12/28 22:07:15/

1. 问题描述

环境:TensorRT-8.6.1.6、CUDA-11.8

报错:Error[4]: /If_OutputLayer: IIfConditionalOutputLayer inputs must have the same shape. Shapes are [-1,384] and [-1,1,384].

复现代码:

python">import os
import torch
import torch.nn as nn
import torch.nn.functional as Fclass TestModel(torch.nn.Module):def __init__(self, mode):super().__init__()self.mode = modeself.conv = nn.Conv1d(512, 512, 3, 2, 1)def forward(self, x, mask):if self.mode == 1:return self.forward1(x, mask)elif self.mode == 2:return self.forward2(x, mask)elif self.mode == 3:return self.forward3(x, mask)elif self.mode == 4:return self.forward4(x, mask)else:raise ValueError("Invalid mode")def forward1(self, x, mask):mask = mask.unsqueeze(1)x = self.conv(x)mask = F.interpolate(mask, size=x.size(-1), mode="nearest")x = x * maskmask = mask.squeeze(1)return x, maskdef forward2(self, x, mask):mask = mask.unsqueeze(1)x = self.conv(x)mask = F.interpolate(mask, size=384, mode="nearest")x = x * maskmask = mask.squeeze(1)return x, maskdef forward3(self, x, mask):mask = mask.unsqueeze(1)x = x * maskmask = mask.squeeze(1)return x, maskdef forward4(self, x, mask):mask = mask.unsqueeze(1)x = self.conv(x)mask = F.interpolate(mask, size=x.size(-1), mode="nearest")x = x * maskb = x.shape[0]mask = mask.reshape(b, -1)return x, maskfake_input = torch.randn(1, 512, 768)
fake_mask = torch.randn(1, 768)
model1 = TestModel(1)
model2 = TestModel(2)
model3 = TestModel(3)
model4 = TestModel(4)with torch.no_grad():print([x.shape for x in model1(fake_input, fake_mask)])print([x.shape for x in model2(fake_input, fake_mask)])print([x.shape for x in model3(fake_input, fake_mask)])print([x.shape for x in model4(fake_input, fake_mask)])dynamic = {"x_input": {0: "batch"},"masks_input": {0: "batch"},"x_output": {0: "batch"},"masks_output": {0: "batch"}
}save_dir = "log"
os.makedirs(save_dir, exist_ok=True)for i, model in enumerate((model1, model2, model3, model4)):torch.onnx.export(model.cpu().eval(),(fake_input.cpu(), fake_mask.cpu()),os.path.join(save_dir, f'dynamic_{i+1}.onnx'),verbose=True,opset_version=17,do_constant_folding=True,input_names=["x_input", "masks_input"],output_names=["x_output", "masks_output"],dynamic_axes=dynamic)torch.onnx.export(model.cpu().eval(),(fake_input.cpu(), fake_mask.cpu()),os.path.join(save_dir, f'static_{i+1}.onnx'),verbose=True,opset_version=17,do_constant_folding=True,input_names=["x_input", "masks_input"],output_names=["x_output", "masks_output"],dynamic_axes=None)
exec="/home/sfy/SFY/camera/TensorRT-8.6.1.6/bin/trtexec"
dir="log"
rm "$dir"/*.txt "$dir"/*.plan: <<'EOF'
# onnx-simplifier 移除冗余节点
for file in "$dir"/*.onnx; doif [[ ! -f "$file" ]]; thenecho "No .onnx files found in $dir."continuefifn=$(basename "$file")python -m onnxsim "$file" "$dir/${fn%.onnx}_simp.onnx"
done
EOFfor file in "$dir"/*.onnx; doif [[ ! -f "$file" ]]; thenecho "No .onnx files found in $dir."continuefifn=$(basename "$file")if [[ $fn == dynamic* ]]; then$exec \--onnx="$file" \--saveEngine="$dir/${fn%.onnx}.plan" \--minShapes=x_input:1x512x768,masks_input:1x768 \--optShapes=x_input:2x512x768,masks_input:2x768 \--maxShapes=x_input:4x512x768,masks_input:4x768 \--verbose \> "$dir/${fn%.onnx}.txt" 2>&1elif [[ $fn == static* ]]; then$exec \--onnx="$file" \--saveEngine="$dir/${fn%.onnx}.plan" \--verbose \> "$dir/${fn%.onnx}.txt" 2>&1elsecontinuefi
done

  上述代码对源代码做了简化,仅剥离出造成问题的部分,用于复现和测试 Bug,输入可以看作 1 个 Batch 有 768 张图像序列,每个图像用 512 维特征向量表示;masks 原本是 bool 类型代表 768 张图像是真实数据还是填充数据。

  经过测试发现 Bug 由动态模式、F.interpolatesqueeze 组合引发,因此代码推理阶段分为以下 4 种模式:
(1)F.interpolate(mask, size=x.size(-1), mode="nearest") + mask.squeeze(1) 报错
(2)F.interpolate(mask, size=384, mode="nearest") + mask.squeeze(1) 动态模式报错,静态模式通过
(3)mask.squeeze(1) 通过
(4)F.interpolate(mask, size=x.size(-1), mode="nearest") + mask.reshape(b, -1)mask.view(b, -1) 通过

2. 解决方法

  用 reshapeview 代替 squeeze

3. 原因分析

  总结:squeeze 操作需要判断对应维度是否等于 1,而 F.interpolate 改变了张量形状、动态模式引入了维度数值的不确定性,这些使得 onnx 无法确定该维度是否等于 1(即使看上去可以推导出数值),导致添加了 If 节点,而 If 节点在不同状态下分别执行 SqueezeIdentity 导致输出形状不统一。采用 Reshape 操作可以直接规避此问题。

  下面是不同模式 ONNX 的结构图对比,以及分析细节。

mode1

在这里插入图片描述在这里插入图片描述
  mode1 在插值时使用 x.size(-1),在动态模式下需要通过 Shape 等一系列节点来获取维度信息;在静态模式下 x 所有初始维度固定,把 x.size(-1) 看作固定常量。

mode2

在这里插入图片描述在这里插入图片描述
  动态模式下 masks 经过 Resize 维度信息是不确定的,导致需要判断 masks.shape[1] == 1 才能执行 masks.squeeze(1),问题在于当不等于 1 时执行的是 Identity 导致输出形状不一致。

  比较令人费解的是对比 mode1 和 mode2 静态模式,可以发现 Resize 之前的结构完全相同,但是 mode1 在 Resize 之后仍引入了 If 节点导致异常。查到的解释是 squeeze 操作比较保守,mode1 原本 Resize 的维度是动态的依赖 x.size(-1),即使通过推导将动态转变为静态,但仍保留了动态逻辑(If 节点)。
  启用脚本中 onnx-simplifier 移除冗余节点的部分可以去除 mode1 静态模式中的 If 节点,但此方法对动态模式无效。

mode3

在这里插入图片描述在这里插入图片描述
  masks 维度的动态(不确定性)由 Resize 引入。

mode4

在这里插入图片描述在这里插入图片描述
  不使用 squeeze 便不需要判断维度是否符合要求,直接规避此问题。


http://www.ppmy.cn/embedded/149567.html

相关文章

流架构的读书笔记(2)

流架构的读书笔记&#xff08;2&#xff09; 一、建模工具之一沃德利地图 推测技术的发展,交流和辩论思想的最有力的方法是沃德利地图 沃德利地图的制作步骤 1确定范围和用户需求 2确定满足用户需求所需的组件 3在一条范围从全新到被人们接受的演进轴上评估这些组成 部分的演…

【Spring核心思想】IoC容器与依赖倒置(DI)

在日常开发中&#xff0c;我们总会面临一个问题&#xff1a;如何优雅地管理对象的创建和依赖&#xff1f; 你可能会写一堆代码来手动构造对象&#xff0c;但这种方式繁琐且难以维护。而当项目变得复杂&#xff0c;依赖链拉长&#xff0c;手动管理对象的方式很快就会捉襟见肘。 …

项目文档-代码检查报告

在项目验收阶段会需要很多项目报告&#xff0c;这里记录一下代码检查报告的整理方式。 代码检查报告&#xff1a;项目需要记录每一个改动的代码类&#xff0c;并记录检查结果和修改情况&#xff0c;以及检查人员。 检查结果&#xff1a;是否检查通过 修改情况&#xff1a;暂无需…

用VBA自动更正错误的注释引用序号

将扫描pdf文件进行文字识别时&#xff0c;对带圈数字表示的注释引用和注释序号往往会将数字序号认错。例如下面的文件&#xff1a; 这个文件的段落十分有规律&#xff1a;每首诗的标题样式为标题3&#xff0c;标题下面的段落为诗的正文&#xff0c;下面有一个样式为标题4的段落…

怎么配置每一次重启服务器后,自动启动Tocmat

前言 宝子们&#xff0c;今天来给大家详细讲讲服务器如何配置每次重启后自动启动 Tomcat&#xff0c;让你的服务器应用始终保持在线状态&#xff0c;高效运行&#xff01; windows版本 在 Windows 系统下&#xff0c;有两种常用的方法可以实现这个目标。 第一种方法是利用服…

SQL 实战:字符串处理函数 – 数据清洗与文本格式化

在数据分析和开发过程中&#xff0c;原始数据往往存在格式不统一、冗余字符等问题&#xff0c;直接影响查询和展示效果。SQL 提供了一系列强大的字符串处理函数&#xff0c;能够帮助开发者进行数据清洗和文本格式化操作&#xff0c;提高数据质量和查询效率。本文将通过多个实战…

Linux下Java通过JNI调用C++

以下为Demo流程 1.创建Java文件 public class HelloWord {// 声明本地方法public native void sayHello();static {// 加载本地库System.loadLibrary("hello");}public static void main(String[] args) {new HelloWord().sayHello();} } 2.编译生成.h头文件 在H…

5.npm包

文章目录 [TOC](文章目录) 3.npm与包3.1.包3.2.npm体验在项目中安装包的命令包管理配置文件一次性安装开发项目时安装的包如何从项目中卸载包devDependencies节点的作用解决下载包速度比较慢的问题nrm工具&#xff0c;利用其提供的终端命令&#xff0c;可以快速查看和切换下包的…