pytorch自定义算子导出onnx

devtools/2024/11/25 2:50:03/

文章目录

      • 1、为什么要自定义算子?
      • 2、如何自定义算子
      • 3、自定义算子导出onnx
      • 4、example
        • 1、重写一个pytorch 自定义算子(实现自定义激活函数)
        • 2、现有算子上封装pytorch 自定义算子(实现动态放大超分辨率模型)

1、为什么要自定义算子?

1、没有现成可用的算子,需要根据自己的接口重写。
2、现有的算子接口不兼容,需要在原有的算子上进行封装。

2、如何自定义算子

继承torch.autograd.Function类,实现其forward()backward()方法,就可以成为一个pytorch自定义算子。就可以在模型训练推理中完成前向推理和反向传播。
forward() 函数的第一个参数必须是ctx, 后面是输入。
在工程部署上,一般为了加快计算,自定义算子需要用cuda 实现forward()、backward()kernel 函数。

3、自定义算子导出onnx

实现其symbolic 静态方法,当我们调用torch.onnx.export()时,就可以导出onnx 算子。
symbolic是符号函数,通常在其内部返回一个g.op()对象。g.op() 把一个 PyTorch 算子映射成一个或多个 ONNX 算子,或者是自定义的 ONNX 算子。
symbolic函数的第一个参数必须是g, 后面是和forward()对应的输入。
g.op() 做算子映射,g.op 的参数:
1、第一个参数为算子名字
2、后面参数与forward() 输入对应
3、往后可以是一些算子自带常量和属性值。常量视为输入,属性值需要用 字段_s/i/f = 默认值表示。_s 表示字符串,_i 表示 int64, _f 表示 float32。常量用类似 g.op(“Constant”, value_t=torch.tensor([3, 2, 1], dtype=torch.float32))表示

4、example

pytorch__19">1、重写一个pytorch 自定义算子(实现自定义激活函数)

实现自己的激活函数MYSELU 算子。

python">import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
import torch.autograd#继承torch.autograd.Function
class MYSELUImpl(torch.autograd.Function): @staticmethoddef symbolic(g, x, p):return g.op("MYSELU", x, p,  # 表示onnx算子的名称为MYSELU,参数与forward()对应# 给算子传一个常数参数g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),attr1_s="这是字符串属性", # s表示字符串attr2_i=[1, 2, 3], # i表示整数attr3_f=222  # f表示浮点数)@staticmethoddef forward(ctx, x, p): # 前行推理return x * 1 / (1 + torch.exp(-x))class MYSELU(nn.Module): def __init__(self, n):super().__init__()self.param = nn.parameter.Parameter(torch.arange(n).float())def forward(self, x):return MYSELUImpl.apply(x, self.param) #推理调用class Model(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(1, 1, 3, padding=1)self.myselu = MYSELU(3)self.conv.weight.data.fill_(1)self.conv.bias.data.fill_(0)def forward(self, x):x = self.conv(x)x = self.myselu(x)return x
pytorch__64">2、现有算子上封装pytorch 自定义算子(实现动态放大超分辨率模型)

实现动态放大超分辨率模型。我们希望实现:
forward(self, x, upscale_factor)
这样一个接口,x 为图像输入,upscale_factor为动态放大倍数。
pytorch 现有放大算子有nn.Upsample 和 interpolate, 但是nn.Upsample 在初始化阶段固化了放大倍数,而 PyTorch 的 interpolate 插值算子可以在运行阶段选择放大倍数。

python">class SuperResolutionNet(nn.Module): def forward(self, x, upscale_factor): x = interpolate(x, scale_factor=upscale_factor.item(), mode='bicubic', align_corners=False) 
... 
# Inference 
# Note that the second input is torch.tensor(3) 
torch_output = model(torch.from_numpy(input_img), torch.tensor(3)).detach().numpy() 
... 
with torch.no_grad(): torch.onnx.export(model, (x, torch.tensor(3)), "srcnn2.onnx", opset_version=11, input_names=['input', 'factor'], output_names=['output']) 

尝试使用以上方法导出onnx 时,虽然没有报错能成功导出onnx,但是有TraceWarning 的警告,说明导出onnx有追踪失败。这是由于我们使用了 torch.Tensor.item() 把数据从 Tensor 里取出来,而导出 ONNX 模型时这个操作是无法被记录的,只好报了一条 TraceWarning。

因此我们需要自定义算子,让onnx在追踪时刻能work。我们看到nn.Upsample 和 interpolate在转onnx时都映射到了Resize 操作。所以自定义算子在Resize 操作上进行封装即可。
在这里插入图片描述

Resize 操作有三个输入,x, roi, scale, 我们就是要动态输入scale。展开 scales,可以看到 scales 是一个长度为 4 的一维张量,其内容为 [1, 1, 3, 3],
如果我们能够自己生成一个 ONNX 的 Resize 算子,让 scales 成为一个可变量而不是常量,就像它上面的 X 一样,那这个超分辨率模型就能动态缩放了。

python">import torch 
from torch import nn 
from torch.nn.functional import interpolate 
import torch.onnx 
import cv2 
import numpy as np 
class NewInterpolate(torch.autograd.Function): @staticmethod def symbolic(g, input, scales): return g.op("Resize", input, g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)), scales, coordinate_transformation_mode_s="pytorch_half_pixel", cubic_coeff_a_f=-0.75, mode_s='cubic', nearest_mode_s="floor") @staticmethod def forward(ctx, input, scales): scales = scales.tolist()[-2:] return interpolate(input, scale_factor=scales, mode='bicubic', align_corners=False) class StrangeSuperResolutionNet(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4) self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0) self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2) self.relu = nn.ReLU() def forward(self, x, upscale_factor): x = NewInterpolate.apply(x, upscale_factor) out = self.relu(self.conv1(x)) out = self.relu(self.conv2(out)) out = self.conv3(out) return out 

以上自定义了Resize 算子,将scale 作为算子的一个输入,最后还是调用interpolate。但是scale已经变成自定义输入参数。
参数映射如下:
在这里插入图片描述


http://www.ppmy.cn/devtools/136720.html

相关文章

计算机网络:概述知识点及习题练习

网课资源: 湖科大教书匠 1、因特网 网络之间需要路由器进行互联,互联网是网络的网络,因特网是最大的互联网,连接到网络的设备称为主机,一般不叫路由器为主机。 因特网发展:ARPNET->三级结构因特网&am…

项目学习:仿b站的视频网站项目03-注册功能

概括 通过上一期,完成了项目和数据库的基础结构的搭建,接下来主要是完成项目的注册功能。该功能模块主要分为有两个接口,一个是验证码接口,一个是注册接口。 让我们开始吧! 验证码接口 验证码的生成主要配合下面这…

Python自动化测试实践中pytest用到的功能dependency和parametrize

Python自动化测试中pytest用到的功能 1、pytest之@pytest.mark.dependency装饰器设置测试用例之间的依赖关系 1.1说明: 1、这是一个pytest第三方插件,主要解决用例之间的依赖关系。如果依赖的上下文测试用例失败后续的用例会被标识为跳过执行,相当于执行了 pytest.mark.s…

#渗透测试#红蓝攻防#HW#SRC漏洞挖掘01之静态页面渗透

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停…

SpringBoot 集成 html2Pdf

一、概述&#xff1a; 1. springboot如何生成pdf&#xff0c;接口可以预览可以下载 2. vue下载通过bold如何下载 3. 一些细节&#xff1a;页脚、页眉、水印、每一页得样式添加 二、直接上代码【主要是一个记录下次开发更快】 模板位置 1. 导入pom包 <dependency><g…

docker使用阿里云容器镜像服务下载公共镜像

写在当所有的加速镜像源都失效之后…… 1&#xff0c;登录阿里云进入容器镜像服务 阿里云登录 - 欢迎登录阿里云&#xff0c;安全稳定的云计算服务平台 2&#xff0c;搜索开源的镜像 3&#xff0c;查看下载连接 4&#xff0c;拉取镜像 5&#xff0c;自己建私服 阿里镜像仓库…

机器学习实战记录(1)

决策树——划分数据集 def splitDataSet(dataSet, axis, value): retDataSet [] #创建返回的数据集列表for featVec in dataSet: #遍历数据集if featVec[axis] value:reducedFeatVec featVec[:axis] #去掉axis特征reducedFeatVec.extend(featVec[axis1…

【Golang】手搓DES加密

代码非常长 有六百多行 参考一位博主的理论实现 通俗易懂&#xff0c;十分钟读懂DES 还有很多不足的地方 感觉只是个思路 S盒&#xff08;理论既定&#xff09; package src// 定义S - 盒的置换表 var SBoxes [8][4][16]int{{{14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, …