ONNX 输入batch修改
导出的onnx模型分为静态和动态输入两种,但一般用户会在导出后进行onnxsim操作,导致某些非全卷积的模型修改batch失败,比如transformer类其中reshape的attr属性会固定,修改相当麻烦,需要从源头重新导出是最佳选择,本文讨论在没有源头的情况下,尽量完成修改batch的修改,并固化每个op的输入输出tensor为正确shape。
依赖
- onnx
- onnxsim
import sys
import onnx
from onnx import shape_inference, helper
from onnxsim import simplifyfilepath = sys.argv[1]
batch = int(sys.argv[2])
opset_version = int(sys.argv[3])model = onnx.load(filepath)# 创建一个新的模型图
new_graph = helper.make_graph(nodes=model.graph.node, # 保留所有节点name="new_graph", # 保留原图名称inputs=model.graph.input, # 保留输入outputs=model.graph.output, # 保留输出initializer=model.graph.initializer, # 保留初始化器value_info=None
)# 构造一个新的模型
if not any(opset.domain == "" for opset in model.opset_import):model.opset_import.append(helper.make_opsetid(domain="", version=opset_version))new_model = helper.make_model(new_graph, producer_name=model.producer_name, opset_imports=model.opset_import)
for idx, _input in enumerate(new_model.graph.input):_input.type.tensor_type.shape.dim[0].dim_value = batch
for idx, _output in enumerate(new_model.graph.output):_output.type.tensor_type.shape.dim[0].dim_value = batchmodel_sim, success = simplify(new_model)
if not success:print("simplify failed")onnx.save(new_model, filepath.replace(".onnx", f"_b{batch}.onnx"))
else:onnx.save(model_sim, filepath.replace(".onnx", f"_b{batch}.onnx"))