【深度学习】pytorch pth模型转为onnx模型后出现冗余节点“identity”,onnx模型的冗余节点“identity”

news/2024/11/15 7:32:38/

情況描述

onnx模型的冗余节点“identity”如下图。
在这里插入图片描述

解决方式

首先,确保您已经安装了onnx-simplifier库:

pip install onnx-simplifier

然后,您可以按照以下方式使用onnx-simplifier库:

import onnx
from onnxsim import simplify# 加载导出的 ONNX 模型
onnx_model = onnx.load("your_model.onnx")# 简化模型
simplified_model, check = simplify(onnx_model)# 保存简化后的模型
onnx.save_model(simplified_model, "simplified_model.onnx")

通过这个过程,onnx-simplifier库将会检测和移除不必要的"identity"节点,从而减少模型中的冗余。

请注意,使用onnx-simplifier库可能会改变模型的计算图,因此在使用简化后的模型之前,务必进行测试和验证以确保其功能没有受到影响。

问题原因

在将 PyTorch 模型转换为 ONNX 格式时,有时会出现冗余的"identity"节点的问题。这是因为 PyTorch 和 ONNX 在计算图构建和表示方式上存在一些差异。

在 PyTorch 中,计算图是动态构建的,其中包含了很多临时变量和操作。但在 ONNX 中,计算图是静态定义的,每个操作都显式地表示为一个节点。这种差异可能导致在将 PyTorch 模型转换为 ONNX 格式时引入一些不必要的中间"identity"节点。

一个常见的原因是,PyTorch 中的某些操作或模型结构在 ONNX 中没有直接的等价表示。为了保持模型结构的一致性,转换过程中可能会引入额外的"identity"节点,用于保留原始模型中的特定计算图结构或操作。

另外,有时候这些"identity"节点并不会对模型的性能或功能产生任何影响,它们只是在图形表示上引入了一些冗余。这些冗余节点在模型尺寸较小的情况下可能并不明显,但对于大型模型来说可能会显著增加模型文件的大小。

通过使用onnx-simplifier库,您可以对导出的 ONNX 模型进行后处理,去除这些不必要的"identity"节点,从而减少模型的冗余。

需要注意的是,由于 PyTorch 和 ONNX 之间的差异,无法完全避免所有的冗余节点。但大部分情况下这些冗余节点并不会对模型的性能或功能产生实质性的影响。

我的模型代码

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import initclass hswish(nn.Module):def forward(self, x):out = x * F.relu6(x + 3, inplace=True) / 6return outclass hsigmoid(nn.Module):def forward(self, x):out = F.relu6(x + 3, inplace=True) / 6return out# 注意力机制
class SeModule(nn.Module):def __init__(self, in_channel, reduction=4):super(SeModule, self).__init__()self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc1 = nn.Conv2d(in_channel, in_channel // reduction, kernel_size=1, stride=1, padding=0, bias=False)self.bn = nn.BatchNorm2d(in_channel // reduction)self.relu = nn.ReLU(inplace=True)self.fc2 = nn.Conv2d(in_channel // reduction, in_channel, kernel_size=1, stride=1, padding=0, bias=False)self.hs = hsigmoid()def forward(self, x):out = self.avgpool(x)out = self.fc1(out)out = self.bn(out)out = self.relu(out)out = self.fc2(out)out = self.hs(out)return x * out# 线性瓶颈和反向残差结构
class Block(nn.Module):def __init__(self, kernel_size, in_channel, expand_size, out_channel, nolinear, semodule, stride):super(Block, self).__init__()self.stride = strideself.se = semodule# 1*1展开卷积self.conv1 = nn.Conv2d(in_channel, expand_size, kernel_size=1, stride=1, padding=0, bias=False)self.bn1 = nn.BatchNorm2d(expand_size)self.nolinear1 = nolinear# 3*3(或5*5)深度可分离卷积self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,padding=kernel_size // 2, groups=expand_size, bias=False)self.bn2 = nn.BatchNorm2d(expand_size)self.nolinear2 = nolinear# 1*1投影卷积self.conv3 = nn.Conv2d(expand_size, out_channel, kernel_size=1, stride=1, padding=0, bias=False)self.bn3 = nn.BatchNorm2d(out_channel)self.shortcut = nn.Sequential()if stride == 1 and in_channel != out_channel:self.shortcut = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(out_channel),)def forward(self, x):out = self.nolinear1(self.bn1(self.conv1(x)))out = self.nolinear2(self.bn2(self.conv2(out)))out = self.bn3(self.conv3(out))# 注意力模块if self.se != None:out = self.se(out)# 残差链接out = out + self.shortcut(x) if self.stride == 1 else outreturn outclass MobileNetV3_Small_050(nn.Module):def __init__(self):super(MobileNetV3_Small_050, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(16)self.hs1 = nn.ReLU(inplace=True)self.bneck = nn.Sequential(Block(3, 16, 8, 16, nn.ReLU(inplace=True), SeModule(16), 2),Block(3, 16, 40, 16, nn.ReLU(inplace=True), None, 2),Block(3, 16, 56, 16, nn.ReLU(inplace=True), None, 1),Block(5, 16, 64, 24, hswish(), SeModule(24), 2),Block(5, 24, 144, 24, hswish(), SeModule(24), 1),Block(5, 24, 144, 24, hswish(), SeModule(24), 1),Block(5, 24, 72, 24, hswish(), SeModule(24), 1),Block(5, 24, 72, 24, hswish(), SeModule(24), 1),Block(5, 24, 144, 48, hswish(), SeModule(48), 2),Block(5, 48, 288, 48, hswish(), SeModule(48), 1),Block(5, 48, 288, 48, hswish(), SeModule(48), 1),)self.conv2 = nn.Conv2d(48, 288, kernel_size=1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(288)self.hs2 = hswish()self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Linear(288, 6)self.init_params()def init_params(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):out = self.hs1(self.bn1(self.conv1(x)))out = self.bneck(out)out = self.hs2(self.bn2(self.conv2(out)))out = self.avgpool(out)out = out.view(-1, 288)out = self.fc(out)return outclass MobileNetV3_Small(nn.Module):def __init__(self):super(MobileNetV3_Small, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(16)self.hs1 = hswish()self.bneck = nn.Sequential(Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),Block(5, 24, 96, 40, hswish(), SeModule(40), 2),Block(5, 40, 240, 40, hswish(), SeModule(40), 1),Block(5, 40, 240, 40, hswish(), SeModule(40), 1),Block(5, 40, 120, 48, hswish(), SeModule(48), 1),Block(5, 48, 144, 48, hswish(), SeModule(48), 1),Block(5, 48, 288, 96, hswish(), SeModule(96), 2),Block(5, 96, 576, 96, hswish(), SeModule(96), 1),Block(5, 96, 576, 96, hswish(), SeModule(96), 1),)self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(576)self.hs2 = hswish()self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Linear(576, 6)self.init_params()def init_params(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):out = self.hs1(self.bn1(self.conv1(x)))out = self.bneck(out)out = self.hs2(self.bn2(self.conv2(out)))out = self.avgpool(out)out = out.view(-1, 576)out = self.fc(out)return outif __name__ == '__main__':# from torchsummary import summary# net = MobileNetV3_Small_050().train()# summary(net, (3, 64, 64))## from torchstat import stat# net = MobileNetV3_Small_050().train()# stat(net, input_size=(3, 64, 64))  # 输出模型的FLOPs和参数数量# 转为onnximport torch.onnxdummy_input = torch.randn(1, 3, 64, 64)net = MobileNetV3_Small_050().eval()torch.onnx.export(net, dummy_input, "mobilenetv3_small_050.onnx", input_names=["input"], output_names=["output"],opset_version=11, )

http://www.ppmy.cn/news/348346.html

相关文章

WPS vbe6ex.olb 不能加载

WPS vbe6ex.olb 不能加载折腾经历 一直用MS OFFICE, 偶然试用了WPS,一见钟情不能自拔.XLWINGS插件,VBA宏都能正常使用. 本人有强迫症,见不得电脑上有偷懒的东西存在, 于是果断卸载MS OFFICE. 于是问题来了,当XLWINGS需要import functions时,提示vbe6ex.olb 不能加载.虽然只要重…

vb.net开发vbe插件,在vbe界面生成类似任务窗格的窗体

要在 VB.NET 中开发 VBE 插件并生成类似任务窗格的窗体,您需要做以下几件事: 安装 Microsoft Visual Studio 开发环境。新建一个 VB.NET 项目,选择模板为 "VB.NET 自定义控件库"。使用 ToolStripContainer 控件创建一个包含 ToolSt…

VBE窗口组成

打开Excel应用程序,在"开发工具"选项卡下点击"Visual Basic"按钮或者按下快捷键组合AltF11即可打开VBE窗口,其主要成分如图1-1所示。 图1-1 VBE窗口 窗口中主要包括菜单栏、工具栏、工程资源管理器、代码窗口、属性窗口、立即窗口、…

熟悉VBA的编程环境---VBE

打开EXCEL应用程序,在“开发工具”选项卡下点击“Visual Basic”按钮或者直接按下快捷键组合ALTF11即打开VBE窗口,其主要组成部分如图: 窗口中主要包括菜单栏、工具栏、工程资源管理器、代码窗口、属性窗口、立即窗口、监视窗口等。这些窗口…

微信一天可以加多少个好友?

微信作为最大的私域流量池,几乎所有的人都会往微信引流,而微信每天加好友数量是有严格限制的。微信每天加多少人不会封号?微信每天加多少好友才不会被限制?微信频繁加好友被限制怎么办?请跟随小编的脚步一起往下看吧。…

Protocol https not supported or disabled in libcurl

原因 curl默认安装完后是只支持http协议而不支持https协议的。 curl -V查看当前curl支持哪些协议: [rootlocalhost /]# curl -V curl 7.19.4 (x86_64-unknown-linux-gnu) libcurl/7.19.4 OpenSSL/1.0.2k zlib/1.2.11 Protocols: tftp ftp telnet dict http fil…

Vue3-03-Vue2 响应式 VS Vue3 响应式

本文来讲解从 Vue2 到 Vue3 响应式底层的一些改变。 前言 Vue 2.x 为什么不监听数组下标索引值的变化? 参考了很多博主的推文,自己也尝试了一下,Object.defineProperty 是可以做到监听数组的索引值的变化的,来做 getter 和 sette…

【摄像头】摄像机工作原理

【目录】郭老二博文之:图像视频汇总 1、摄像机工作原理 外部光线穿过镜头(lens)后, 经过滤光片(color filter)滤波后照射到光学传感器(Sensor)上面, Sensor 将从 lens 上传导过来的光线转换为电信号,再通过内部的 AD 转换为数字…