ONNX 输入batch修改

ops/2024/11/25 13:30:31/

ONNX 输入batch修改

导出的onnx模型分为静态和动态输入两种,但一般用户会在导出后进行onnxsim操作,导致某些非全卷积的模型修改batch失败,比如transformer类其中reshape的attr属性会固定,修改相当麻烦,需要从源头重新导出是最佳选择,本文讨论在没有源头的情况下,尽量完成修改batch的修改,并固化每个op的输入输出tensor为正确shape。

依赖

  • onnx
  • onnxsim
import sys
import onnx
from onnx import shape_inference, helper
from onnxsim import simplifyfilepath = sys.argv[1]
batch = int(sys.argv[2])
opset_version = int(sys.argv[3])model = onnx.load(filepath)# 创建一个新的模型图
new_graph = helper.make_graph(nodes=model.graph.node,                # 保留所有节点name="new_graph",                      # 保留原图名称inputs=model.graph.input,              # 保留输入outputs=model.graph.output,            # 保留输出initializer=model.graph.initializer,   # 保留初始化器value_info=None
)# 构造一个新的模型
if not any(opset.domain == "" for opset in model.opset_import):model.opset_import.append(helper.make_opsetid(domain="", version=opset_version))new_model = helper.make_model(new_graph, producer_name=model.producer_name, opset_imports=model.opset_import)
for idx, _input in enumerate(new_model.graph.input):_input.type.tensor_type.shape.dim[0].dim_value = batch
for idx, _output in enumerate(new_model.graph.output):_output.type.tensor_type.shape.dim[0].dim_value = batchmodel_sim, success = simplify(new_model)
if not success:print("simplify failed")onnx.save(new_model, filepath.replace(".onnx", f"_b{batch}.onnx"))
else:onnx.save(model_sim, filepath.replace(".onnx", f"_b{batch}.onnx"))

http://www.ppmy.cn/ops/136577.html

相关文章

Git命令使用与原理详解

1.仓库 # 在当前目录新建一个Git代码库 $ git init ​ # 新建一个目录,将其初始化为Git代码库 $ git init [project-name] ​ # 下载一个项目和它的整个代码历史 $ git clone [url]2.配置 # 显示当前的Git配置 $ git config --list ​ # 编辑Git配置文件 $ git co…

UE5 slate BlankProgram独立程序系列

源码版Engine\Source\Programs\中copy BlankProgram文件夹,重命名为ASlateLearning,修改所有文件命名及内部名称。 ASlateLearning.Target.cs // Copyright Epic Games, Inc. All Rights Reserved.using UnrealBuildTool; using System.Collections.Ge…

Java的包装类及其缓存机制

Java的包装类及其缓存机制 ​ Java 的包装类(Wrapper Classes)是为每种基本数据类型提供的对象表示。基本数据类型(如 int、double 等)是非对象类型,而包装类为它们提供了对应的对象版本,以便可以在需要对…

手撕一个阻塞队列

目录 手撕一个阻塞队列代码讲解 手撕一个阻塞队列 要手撕一个阻塞队列:就要实现一个普通的队列,加上阻塞,加上锁 代码 class MyBlockingQueue{private String[] elemsnull;private int tail0;private int head0;private int size0;private…

鸿蒙NEXT开发-Navigation组件导航

注意:博主有个鸿蒙专栏,里面从上到下有关于鸿蒙next的教学文档,大家感兴趣可以学习下 如果大家觉得博主文章写的好的话,可以点下关注,博主会一直更新鸿蒙next相关知识 专栏地址: https://blog.csdn.net/qq_56760790/…

OmniDiskSweeper :一款专为 macOS 设计的磁盘使用分析工具

OmniDiskSweeper 是一款专为 macOS 设计的磁盘使用分析工具,由 The Omni Group 开发。它的主要目的是帮助用户可视化磁盘上的文件和文件夹,并找出占用大量空间的文件,从而帮助用户释放磁盘空间。 OmniDiskSweeper 的特点包括: 简…

【反向迭代器】—— 我与C++的不解之缘(十七)

前言 ​ 在STL中的迭代器部分,之前只关注与正向迭代器,忽视了反向迭代器;现在来看一下反向迭代器到底是个什么东西,以及反向迭代器怎么实现,怎么为之前自己模拟实现的容器增加反向迭代器? 反向迭代器的使用…

windows vscode C++ 简明教程

windows vscode C++ 简明教程 1 安装mingw64 MinGW-w64(Minimalist GNU for Windows 64)是一个开源工具集,用于在 Windows 系统上编译和生成原生的 Windows 应用程序。它是 MinGW 项目的扩展版本,支持 32 位和 64 位 Windows 程序开发。MinGW-w64 提供了 Windows 版的 GN…