使用 PyTorch 实现并测试 AlexNet 模型,并使用 TensorRT 进行推理加速

embedded/2024/11/14 4:33:12/

本篇文章详细介绍了如何使用 PyTorch 实现经典卷积神经网络 AlexNet,并利用 Fashion-MNIST 数据集进行训练与测试。在训练完成后,通过 TensorRT 进行推理加速,以提升模型的推理效率。
本文全部代码链接:全部代码下载

环境配置

为了保证代码在 GPU 环境下顺利运行,我们将安装兼容 CUDA 11.3 的 PyTorch 版本。请使用以下命令安装 PyTorch、Torchvision 和 Torchaudio:

python">!pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113

为确保兼容性,还可以使用特定版本的 numpy:

python">!pip install numpy==1.23.0

数据加载与预处理

我们将使用 torchvision.datasets.FashionMNIST 加载 Fashion-MNIST 数据集,并对数据进行标准化处理。

将图像转换为张量
归一化图像到 [-1, 1]

python">from torchvision import datasets, transforms
from torch.utils.data import DataLoader

定义数据预处理

python">transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 归一化到 [-1, 1]
])

加载数据集

python">train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

定义数据加载器

python">train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

AlexNet 模型定义

AlexNet 是一种包含 5 层卷积层和 3 层全连接层的经典深度卷积神经网络。以下代码展示了如何使用 PyTorch 实现 AlexNet 的结构。

python">import torch.nn as nn
import torch.nn.functional as Fclass AlexNet(nn.Module):def __init__(self):super(AlexNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4, padding=1)self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2

http://www.ppmy.cn/embedded/137411.html

相关文章

Flink Source 详解

Flink Source 详解 原文 flip-27 FLIP-27 介绍了新版本Source 接口定义及架构 相比于SourceFunction,新版本的Source更具灵活性,原因是将“splits数据获取”与真“正数据获取”逻辑进行了分离 重要部件 Source 作为工厂类,会创建以下两…

Django博客网站上线前准备事项

Django博客网站上线前准备事项 1. 功能完善与测试 确保博客网站具备以下基础功能,并且经过充分测试: 用户认证:注册、登录、登出、密码重置。文章管理:文章的创建、编辑、发布、删除。分类与标签:文章分类和标签的管…

【css】overflow: hidden效果

1. 不添加overflow: hidden 1.1 效果 上面无圆角 1.2 代码 <template><view class"parent"><view class"child1">child1</view><view class"child2">child2</view></view></template><…

YOLOv11实战宠物狗分类

本文采用YOLOv11作为核心算法框架&#xff0c;结合PyQt5构建用户界面&#xff0c;使用Python3进行开发。YOLOv11以其高效的特征提取能力&#xff0c;在多个图像分类任务中展现出卓越性能。本研究针对5种宠物狗数据集进行训练和优化&#xff0c;该数据集包含丰富的宠物狗图像样本…

【Linux:IO多路复用(select函数)

什么是IO多路复用&#xff1f; 一种网络通信的手段&#xff0c;IO多路复用可以同时监测多个文件描述符&#xff0c;且这个过程是阻塞的&#xff0c;当检测有文件描述符就绪&#xff0c;程序的阻塞就会解除&#xff0c;就可以通过这些就绪的文件描述符进行通信。通过这种方式在…

使用YOLOv9进行图像与视频检测

大家好&#xff0c;YOLOv9 与其前身v8一样&#xff0c;专注于识别和精确定位图像和视频中的对象。本文将介绍如何使用YOLOv9进行图像与视频检测&#xff0c;自动驾驶汽车、安全系统和高级图像搜索等应用在很大程度上依赖于此功能&#xff0c;YOLOv9 引入了比 YOLOv8 更令人印象…

hbase集成phoenix

1.环境 环境准备 三台节点zookeeper三节点hadoop三节点hbase三节点 2.pheonix集成 官网下载地址&#xff0c;需挂梯子&#xff0c;使用官网推荐的对应hbase版本即可 https://phoenix.apache.org/download.html下载及解压 wget https://dlcdn.apache.org/phoenix/phoenix-…

npm list @types/node 命令用于列出当前项目中 @types/node 包及其依赖关系

文章目录 作用示例常用选项示例命令注意事项 1、实战举例**解决方法**1. **锁定唯一的 types/node 版本**2. **清理依赖并重新安装**3. **设置 tsconfig.json 的 types**4. **验证 Promise 类型支持** **总结** npm list types/node 命令用于列出当前项目中 types/node 包及其…