【技能---如何正确导出onnx】

news/2025/2/13 5:35:07/

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 注意事项
  • 案例说明
    • 1. 使用 int 转换来避免直接使用 tensor.size 的返回值:
    • 2. 使用 scale_factor 替代 size 参数:
    • 3. 将 -1 放在 view 操作的 batch 维度:
    • 4. 使用 dynamic_axes 参数指定动态轴:
    • 5. 使用 opset_version=11:
    • 6. 综合应用:
  • 总结


前言

提示:这里可以添加本文要记录的大概内容:


提示:以下是本篇文章正文内容,下面案例可供参考

注意事项

1.对于任何用到shape,size返回值的参数时,例如:tensor.view(tensor.size(0),-1)这类操作,避免直接使用tensor.size的返回值,而是加上int转换,tensor.view(int(tensor.size(0)),-1);

2.对于nn.Upsample或nn.functional.interpolate函数,使用scale_factor指定倍率,而不是使用size参数指定大小;

3.对于reshape、view操作时候,-1指定请放在batch维度。其他维度可以计算出来即可。batch维度禁止指定为大于-1的明确数字;

4.torch.onnx.export指定dynamic_axes参数,并且只指定batch维度。我们只需要动态batch,相对动态的宽高有其他方案;

5.使用opset_version=11,不要低于11;

6.掌握了这些,就可以保证后面各种情况的顺利了。

这些做法的必要性体现在,简化过程的复杂度,去掉gather、shape类的节点,很多时候,部分不这么改看似也是可以但是需求复杂后,依旧存在各类问题。按照说的这么做,基本可以。

案例说明

1. 使用 int 转换来避免直接使用 tensor.size 的返回值:

原始代码:

result = tensor.view(tensor.size(0), -1)

修改后的代码:

result = tensor.view(int(tensor.size(0)), -1)

解释: 使用 int 转换可以确保 tensor.size(0) 返回的结果是整数类型,避免在某些情况下可能导致的类型不匹配问题。

2. 使用 scale_factor 替代 size 参数:

原始代码:

upsampled_tensor = nn.functional.interpolate(input_tensor, size=(height, width))

修改后的代码:

upsampled_tensor = nn.functional.interpolate(input_tensor, scale_factor=(scale_height, scale_width))

解释: 使用 scale_factor 可以更直观地表示上采样的倍率,而不是直接指定目标大小。这样更易读且避免了手动计算目标大小的麻烦。

3. 将 -1 放在 view 操作的 batch 维度:

原始代码:

reshaped_tensor = tensor.view(batch_size, -1, height, width)

修改后的代码:

reshaped_tensor = tensor.view(batch_size, -1, height, width)

解释: 将 -1 放在 batch 维度可以更方便地根据张量的总大小自动计算其他维度的大小。这样可以避免手动计算其他维度的麻烦。

4. 使用 dynamic_axes 参数指定动态轴:

原始代码:

torch.onnx.export(net, dummy_input, onnx_path, dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}})

解释: 在 torch.onnx.export 中使用 dynamic_axes 参数,只指定 batch
维度,因为它是唯一需要动态变化的。其他轴,如序列长度,可以使用其他方案进行处理。

5. 使用 opset_version=11:

原始代码:

torch.onnx.export(net, dummy_input, onnx_path, opset_version=9)

修改后的代码:

torch.onnx.export(net, dummy_input, onnx_path, opset_version=11)

解释: 使用 opset_version=11 可以确保使用 ONNX 的较新功能和操作,提高模型的兼容性和性能。

6. 综合应用:

在综合应用上述建议时,可以考虑以下示例:

import torch
import torch.nn as nn# 示例模型
class ExampleModel(nn.Module):def __init__(self):super(ExampleModel, self).__init__()self.fc = nn.Linear(100, 50)def forward(self, x):return self.fc(x)# 示例输入
dummy_input = torch.randn(32, 100)# 导出模型到ONNX
onnx_path = "example_model.onnx"
net = ExampleModel()
torch.onnx.export(net,dummy_input,onnx_path,opset_version=11,input_names=['input'],output_names=['output']
)

上述示例中,我们使用了 int 转换,避免了手动计算大小,使用了 scale_factor 替代了 size 参数,将 -1 放在了 view 操作的 batch 维度,指定了 dynamic_axes 参数,并设置了 opset_version=11。这样可以确保导出的ONNX模型在简化过程中更加稳定和清晰。


总结

以上就是导出onnx需要注意的一些地方,不足之处,还请大家斧正!!1


http://www.ppmy.cn/news/1339005.html

相关文章

Uni-app 如何上传文件, 使用的API是什么

在uni-app中上传文件的方法有很多,其中一种常用的方法是使用wx.uploadFile() API。该API可以上传本地文件或网络文件,并支持设置请求头、请求参数等选项。 一.引入API import { uploadFile } from /util/request.js;二.使用API 上传文件 uploadFile({…

Django模型(七)

一、聚合与分组查询 1.1、准备数据 class Cook(models.Model):"""厨师"""name = models.CharField(max_length=32,verbose_name=厨师名)level = models.IntegerField(verbose_name=厨艺等级)age = models.IntegerField(verbose_name=年龄)sect …

山东省七五商贸有限公司缝纫设备采购项目(第一批次)

山东省七五商贸有限公司缝纫设备采购项目(第一批次) (招标编号:JCJS-2024-002) 项目所在地区:山东省 一、招标条件 本山东省七五商贸有限公司缝纫设备采购项目(第一批次)已由项目审批/核准/备案机关批准,项目资金来源为其他资金/,招标人为山东省七五商贸…

【JavaSE篇】——内部类

目录 🎓内部类 🎈内部类的分类 🚩实例内部类 一.如何实例内部类对象 二.实例内部类中为什么不能有静态成员变量 (用final解决) 三.在实例内部类对象时,如何访问外部类当中相同的成员变量?…

17. Spring Boot Actuator

17. Spring Boot Actuator Spring Boot执行器(Actuator)提供安全端点,用于监视和管理Spring Boot应用程序。 默认情况下,所有执行器端点都是安全的。 在本章中,将详细了解如何为应用程序启用Spring Boot执行器。 启用Spring Boot Actuator …

秋招面试—JS篇

2024 JavaScript面试题 1.new 操作符的工作原理 ①.创建一个新的空对象 ②.将这个对象的原型设置为函数的 prototype 对象 ③.让函数的this指向该对象,为函数添加属性和方法 ④.最后返回这个对象 2.什么是DOM,什么是BOM? DOM:文档对象…

化工企业能源在线监测管理系统,能源管理新利器

化工企业在开展化工生产活动时,能源消耗量较大,其节能潜力空间也较大,因此必须控制能耗强度,促进能效水平的稳步提升。化工企业通过能源现状的分析,能够实现能源使用情况的实时反馈与监管,从而达到节能减排…

机器学习复习(6)——numpy的数学操作

加减法运算 # 创建两个不同的数组 a np.arange(4) #list(0,1,2,3 b np.array([5,10,15,20]) # 两个数组做减法运算 b-a 运行结果: 计算数组的平方 #b*2代表数组b每个元素乘以2 #b**2代表数组b每个元素的2次方 b**2 运行结果: 计算数组的正弦值 #…