pytorch导出onnx时遇到不支持的算子怎么解决

news/2024/11/15 0:53:04/

在使用pytorch模型训练完成之后,我们现在使用的比较多的一种方法是将pytorch模型转成onnx格式的模型中间文件,然后再根据使用的硬件来生成具体硬件使用的深度学习模型,比如TensorRT。
在从pytorch模型转为onnx时,我们可能会遇到部分算子无法转换的问题,本篇注意记录下解决方法。

在导出onnx时,如果出现报错的算子,可以先在下面的链接中查找onnx算子是否支持
https://github.com/onnx/onnx/blob/main/docs/Operators.md

pytorch中有,onnx中也有的算子

导出时使用的onnx op 版本低导致

这个就好解决了,把op库的版本提高就行,但是有可能提高了版本以后,又出现了原来支持的算子现在又不支持了,这个再说

pytorch中没有注册某个onnx算子

如果是这种情况,就按照下面的方式进行:

from torch.onnx import register_custom_op_symbolic
# 创建一个asinh算子的symblic,符号函数,用来登记
# 符号函数内部调用g.op, 为onnx计算图添加Asinh算子
#   g: 就是graph,计算图
#   也就是说,在计算图中添加onnx算子
#   由于我们已经知道Asinh在onnx是有实现的,所以我们只要在g.op调用这个op的名字就好了
#   symblic的参数需要与Pytorch的asinh接口函数的参数对齐
#       def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)# 在这里,将asinh_symbolic这个符号函数,与PyTorch的asinh算子绑定。也就是所谓的“注册算子”
# asinh是在名为aten的一个c++命名空间下进行实现的# aten是"a Tensor Library"的缩写,是一个实现张量运算的C++库
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)

另外一个写法
这个是类似于torch/onnx/symbolic_opset*.py中的写法
通过torch._internal中的registration来注册这个算子,让这个算子可以与底层C++实现的aten::asinh绑定
一般如果这么写的话,其实可以把这个算子直接加入到torch/onnx/symbolic_opset*.py中

import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)

pytorch中有,onnx中无的算子

继承torch.autograd.Function实现自定义算子

import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolicOperatorExportTypes = torch._C._onnx.OperatorExportTypesclass CustomOp(torch.autograd.Function):@staticmethod def symbolic(g: torch.Graph, x: torch.Value) -> torch.Value:return g.op("custom_domain::customOp2", x)@staticmethoddef forward(ctx, x: torch.Tensor) -> torch.Tensor:ctx.save_for_backward(x)x = x.clamp(min=0)return x / (1 + torch.exp(-x))customOp = CustomOp.apply

然后再自己实现custom_domain::customOp2这个算子,如果用TensorRT,就需要自己实现一个插件。


http://www.ppmy.cn/news/1298931.html

相关文章

12.1SPI驱动框架

SPI硬件基础 总线拓扑结构 引脚含义 DO(MOSI):Master Output, Slave Input, SPI主控用来发出数据,SPI从设备用来接收数据 DI(MISO) :Master Input, Slave Output, SPI主控用来发出数据,SPI从设备用来接收…

基于JavaWeb+BS架构+SpringBoot+Vue校园一卡通系统的设计和实现

基于JavaWebBS架构SpringBootVue校园一卡通系统的设计和实现 文末获取源码Lun文目录前言主要技术系统设计功能截图订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 文末获取源码 Lun文目录 第一章 概述 4 1.1 研究背景 4 1.2研究目的及意义 4 1.3国内外发展现状 4 1…

记录尝试投向不同的岗位——信息化专员——感想

1.保持随时响应的铃声 因为手机开启了远离手机的模式,然后会自动的把手机开启勿扰模式,导致对方打电话过来两次手机都没有响铃,本来就与岗位的匹配度低,然后没接到电话,这样连约面试的机会都没有。 人事提问 1.做过o…

CMU15-445-Spring-2023-Project #2 - B+Tree

前置知识:参考上一篇博文 CMU15-445-Spring-2023-Project #2 - 前置知识(lec07-010) CHECKPOINT #1 Task #1 - BTree Pages 实现三个page class来存储B树的数据。 BTree Page internal page和leaf page继承的基类,只包含两个…

MQTT服务器连接不上的问题

问题描述 环境:阿里云服务器Ubuntu 22.04.3 LTS,安装mosquitto后,在虚拟机端订阅消息出现报错(以前用阿里云Ubuntu20.04 LTS的服务器装上就能用),以下服务器ip是我乱填的 mosquitto_sub -t /iotstuff -h …

npm安装vue,添加淘宝镜像

如果是第一次使用命令栏可能会遇到权限问题。 解决vscode无法运行npm和node.js命令的问题-CSDN博客 安装 在vscode上面的导航栏选择terminal打开新的命令栏 另外可能会遇到网络或者其他的问题,可以添加淘宝镜像 npm install -g cnpm --registryhttps://registry.…

Pytorch从零开始实战16

Pytorch从零开始实战——ResNeXt-50算法的思考 本系列来源于365天深度学习训练营 原作者K同学 对于上次ResNeXt-50算法,我们同样有基于TensorFlow的实现。具体代码如下。 引入头文件 import numpy as np from tensorflow.keras.preprocessing.image import Ima…

thinkphp6入门(15)-- 模型动态构建查询条件

背景 我使用thinkphp6的模型写数据库查询,有多个where条件,但是不确定是否需要添加某个where条件,怎么才能动态得生成查询 链式查询 在ThinkPHP 6中,可以使用链式查询方法来动态地构建查询条件。可以根据参数的值来决定是否添加…