前言
构建onnx方式通常有两种:
1、通过代码转换成onnx结构,比如pytorch —> onnx
2、通过onnx 自定义结点,图,生成onnx结构
本文主要是简单学习和使用两种不同onnx结构,
下面以pow
结点进行分析
方式
方法一:pytorch --> onnx
固定shape
import torch
import torch.nn as nn class JustPow(nn.Module):def __init__(self):super(JustPow,self).__init__()def forward(self,x):x = torch.pow(x, 2)return xnet = JustPow()
model_name = 'JustPow.onnx'#保存ONNX的文件名字
dummy_input = torch.randn(4, 778, 1500)
torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'])
结果如图所示:
动态shape
将第一维度设置为动态shape
# 只需要在这里对应位置修改即可
torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'],dynamic_axes={'input': {0: 'batch_size'},'output': {0: 'batch_size'}})# 可以将得到的模型,进一步进行简化处理
onnxsim 方式
方法二: onnx
import onnx
from onnx import TensorProto, helper, numpy_helperdef run():print("run start....\n")# 待完成return modelif __name__ == "__main__":model = run()onnx.save(model, "./test_rpow.onnx")
运行onnx
import onnx
import onnxruntime
import numpy as np# 检查onnx计算图
def check_onnx(mdoel):onnx.checker.check_model(model)# print(onnx.helper.printable_graph(model.graph))def run(model):print(f'run start....\n')session = onnxruntime.InferenceSession(model,providers=['CPUExecutionProvider'])input_name1 = session.get_inputs()[0].name input_data1= np.random.randn(4,778,1500).astype(np.float32)print(f'input_data1 shape:{input_data1.shape}\n')output_name1 = session.get_outputs()[0].namepred_onx = session.run([output_name1], {input_name1: input_data1})[0]print(f'pred_onx shape:{pred_onx.shape} \n')print(f'run end....\n')if __name__ == '__main__':path = "./pow_dynamic_sim.onnx"model = onnx.load("./pow_dynamic_sim.onnx")check_onnx(model)run(path)