PyTorch 模型转换为 ONNX 格式

news/2024/12/3 5:29:35/

PyTorch 模型转换为 ONNX 格式

在深度学习领域,模型的可移植性和可解释性是非常重要的。本文将介绍如何使用 PyTorch 训练一个简单的卷积神经网络(CNN)来分类 MNIST 数据集,并将训练好的模型转换为 ONNX 格式。我们还将讨论 PTH 和 ONNX 格式的区别,并介绍如何使用 Netron 可视化 ONNX 模型。

1. PTH 和 ONNX 的区别

PTH 格式

  • 定义:PTH 是 PyTorch 框架的专有格式,通常用于保存模型的状态字典(state_dict),包括模型的结构和训练好的参数。

  • 兼容性

    • PTH 文件只能在 PyTorch 中使用,无法直接在 C++ 环境中加载。虽然 PyTorch 提供了 C++ API(LibTorch),但 PTH 文件的加载和使用主要依赖于 Python 环境。
    • 在 C++ 中使用 PTH 文件需要将模型转换为 PyTorch 的 C++ 格式,这可能会增加复杂性和开发时间。
  • 用途

    • PTH 格式适合在 Python 环境中进行模型训练和调试,但在 C++ 中进行模型部署时,通常需要将模型转换为其他格式(如 ONNX)以便于跨平台使用。
    • 在 C++ 中,使用 PTH 文件的灵活性较低,尤其是在需要与其他框架或系统集成时。

ONNX 格式

  • 定义:ONNX(Open Neural Network Exchange)是一个开放的深度学习模型交换格式,旨在促进不同深度学习框架之间的互操作性。

  • 兼容性

    • ONNX 文件可以在多个深度学习框架中使用,包括 PyTorch、TensorFlow、Caffe2 等,这使得它在 C++ 环境中的兼容性更强。
    • ONNX 模型可以通过 ONNX Runtime、TensorRT、OpenVINO 等推理引擎在 C++ 中高效运行,支持多种硬件加速。
  • 用途

    • ONNX 格式非常适合模型的部署和推理,特别是在需要跨平台或跨框架使用时。它允许开发者在 C++ 中轻松加载和运行模型,而无需依赖于 Python 环境。
    • 在 C++ 中,使用 ONNX 模型可以简化工程化流程,便于与其他系统集成,提升模型的可移植性和可扩展性。

总结

在 C++ 进行深度学习模型的工程化时,选择 ONNX 格式通常更为合适,因为它提供了更好的跨平台兼容性和灵活性。PTH 格式虽然在 PyTorch 环境中非常方便,但在 C++ 中的使用受到限制,通常需要额外的转换步骤。ONNX 的开放性和广泛支持使其成为在多种环境中部署深度学习模型的首选格式。

2. 训练 MNIST 数据集的 CNN 模型

以下是使用 PyTorch 训练 MNIST 数据集的完整代码示例:

python">import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 检查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# 1. 数据加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # MNIST 数据集的均值和标准差
])# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 2. 定义 CNN 模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 输入通道为1,输出通道为32self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 输入通道为32,输出通道为64self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化层self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 全连接层self.fc2 = nn.Linear(128, 10)  # 输出层def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))  # 第一层卷积 + 激活 + 池化x = self.pool(torch.relu(self.conv2(x)))  # 第二层卷积 + 激活 + 池化x = x.view(x.size(0), -1)  # 展平输入x = torch.relu(self.fc1(x))  # 第一个全连接层x = self.fc2(x)  # 输出层return x# 3. 训练模型
model = SimpleCNN().to(device)  # 将模型移动到 MPS 设备
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器# 训练过程
num_epochs = 5
for epoch in range(num_epochs):model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到 MPS 设备outputs = model(images)_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')# 5. 转换为 ONNX 格式
onnx_file_path = 'mnist_cnn_model.onnx'
dummy_input = torch.randn(1, 1, 28, 28).to(device)  # 示例输入,形状为 [batch_size, channels, height, width]
torch.onnx.export(model, dummy_input, onnx_file_path, export_params=True,opset_version=11, do_constant_folding=True,input_names=['input'], output_names=['output'],dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})print(f'Model has been converted to ONNX format and saved as {onnx_file_path}.')

3. 使用 Netron 可视化 ONNX 模型

一旦您将模型转换为 ONNX 格式,您可以使用 Netron 来可视化模型结构。Netron 是一个开源的模型可视化工具,支持多种深度学习框架的模型文件格式,包括 ONNX。

使用步骤:
  1. 下载 Netron

    • 您可以访问 Netron 的官方网站 在线使用,或者下载桌面版本。
  2. 打开 ONNX 模型

    • 如果使用在线版本,直接将 mnist_cnn_model.onnx 文件拖放到浏览器窗口中。
    • 如果使用桌面版本,打开 Netron 应用,选择“File” > “Open Model”,然后选择您的 ONNX 文件。
  3. 查看模型结构

    • 在 Netron 中,您可以查看模型的层次结构、输入输出形状、参数数量等信息。通过可视化,您可以更好地理解模型的设计和工作原理。
      在这里插入图片描述

http://www.ppmy.cn/news/1551919.html

相关文章

telnet IP某个端口,但是ping不通IP :网络连接中的不同境遇

在网络探索的旅程中,我们常常会遇到一些看似矛盾的现象,比如能够 Telnet 一个 IP 的端口,却 Ping 不通这个 IP。这究竟是为何呢? 以往,我曾天真地认为,Telnet 通了,Ping 肯定也是通的。毕竟&am…

【Linux探索学习】第十八弹——进程等待:深入解析操作系统中的进程等待机制

Linux学习笔记:https://blog.csdn.net/2301_80220607/category_12805278.html?spm1001.2014.3001.5482 前言: 在Linux操作系统中,进程是资源的管理和执行单元,每个进程都有其自己的生命周期。在进程的执行过程中,进程…

C++关于二叉树的具体实现

目录 1.二叉树的结构 2.创建一棵二叉树 3.二叉树的先序遍历 1.借助栈的先序遍历 2.利用递归的先序遍历 4.二叉树的中序遍历 5.二叉树的后序遍历 1.借助栈的后序遍历 2.利用递归的后序遍历 6.二叉树的层序遍历 7.tree.h 8.tree.cpp 9.main.cpp 1.二叉树的结构 对于…

为什么redis用跳表不用b+树,而mysql用b+树而不是跳表?

写在前面 上一篇文章中,我们深度解析了redis中的跳表结构,而b树的结构我们很久之前就讲过了,那么我们知道了redis的有序集合用的是跳表,而mysql的innodb引擎用的是b树存储,但这是为什么呢?为什么redis用跳…

【新人系列】Python 入门(十四):文件操作

✍ 个人博客:https://blog.csdn.net/Newin2020?typeblog 📝 专栏地址:https://blog.csdn.net/newin2020/category_12801353.html 📣 专栏定位:为 0 基础刚入门 Python 的小伙伴提供详细的讲解,也欢迎大佬们…

HCIE IGP双栈综合实验

实验拓扑 实验需求及解法 本实验模拟ISP网络结构,R1/2组成国家骨干网,R3/4组成省级网络,R5/6/7组成数据中 心网络。 配置所有ipv4地址,请自行测试直连。 R1 sysname R1 interface GigabitEthernet0/0/0ip address 12.1.1.1 255.…

Flink 离线计算

文章目录 一、样例一&#xff1a;读 csv 文件生成 csv 文件二、样例二&#xff1a;读 starrocks 写 starrocks三、样例三&#xff1a;DataSet、Table Sql 处理后写入 StarRocks四、遇到的坑 <dependency><groupId>org.apache.flink</groupId><artifactId&…

C++11新增特性2

一.lambda 1.本质&#xff1a;lambda对象是⼀个匿名函数对象&#xff0c;它可以定义在函数内部。 注&#xff1a;lambda表达式语法使⽤层⽽⾔没有类型&#xff0c;所以我们⼀般是⽤auto或者模板参数定义的对象去接收lambda对象。 2.表达式&#xff1a;[capture-list] (param…