pytorch入门级项目--基于卷积神经网络的数字识别

devtools/2025/2/24 13:10:25/

文章目录

  • 前言
  • 1.数据集的介绍
  • 2.数据集的准备
  • 3.数据集的加载
  • 4.自定义网络模型
    • 4.1卷积操作
    • 4.2池化操作
    • 4.3模型搭建
  • 5.模型训练
    • 5.1选择损失函数和优化器
    • 5.2训练
  • 6.模型的保存
  • 7.模型的验证
  • 结语

前言

本篇博客主要针对pytorch入门级的教程,实现了一个基于卷积神经网络(CNN)架构的数字识别,带你了解由数据集到模型验证的全过程。

1.数据集的介绍

MNIST数据集是一个广泛用于机器学习和计算机视觉领域的手写数字图像数据集。它是深度学习入门的经典数据集,常用于图像识别任务,特别是手写数字识别。该数据集分为训练集和测试集:

  • 训练集:包含60,000张手写数字图像,用于模型的训练。
  • 测试集:包含10,000张手写数字图像,用于模型的评估。

2.数据集的准备

Torchvision在torchvision. datasets模块中提供了许多内置数据集,其中便包含了MNIST数据集,因此可以直接通过torchvision. datasets直接下载。

  • root代表存储路径,此处采用的是相对路径,保存在当前文件所在文件夹的data文件夹下
  • train用来区分训练集和测试集
  • download用来表示是否下载数据集

同时我们可以查看其中一条数据,看看该数据及具体形式

import torchvision
trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True)
testset=torchvision.datasets.MNIST(root='./data',train=False,download=True)
trainset[0]

在这里插入图片描述
这里我的数据集已经下载好了,所以很快执行好了
(<PIL.Image.Image image mode=L size=28x28>, 5)通过输出我们可以发现,第一条数据是一个元组,第一个元素表示data(灰度图,大小28*28),第二个元素表示label,指的是该图片的类别,下面我们可以查看每个类别的含义
在这里插入图片描述
因此我们可以看到第一张图片表示的就是手写数字5,也可以通过trainset[0][0].show()进行显示
在这里插入图片描述
在进行后续操作前,需要将图片格式由PIL调整为tensor类型。

import torchvision
trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=torchvision.transforms.ToTensor())
testset=torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=torchvision.transforms.ToTensor())
trainset[0][0].shape

在这里插入图片描述

至此,数据集准备阶段就算完成了。

3.数据集的加载

在训练深度学习模型时,通常不会一次性将整个数据集输入到模型中,而是将数据集分成多个小批量(mini-batches)进行训练。

import torchtrainloader=torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader=torch.utils.data.DataLoader(testset,batch_size=64,shuffle=False)
type(trainloader),len(trainloader),len(trainset)
  • batch_size:表示批量大小
  • shuffle:表示是否打乱,一般训练集打乱,测试集打乱不打乱都可以
  • num_workers:表示多个进程,windows系统可能报错
    这里提供解决windows系统中多进程报错的可能解决方案:
    加上if __name__ == '__main__':即可,如果你是用的是jupyter,不用此操作

在这里插入图片描述
此时,我们取出一条数据查看类型:
在这里插入图片描述
此时我们可以看到,一条数据为torch.Size([64, 1, 28, 28]),64表示批量大小,1表示该图像单通道,即为灰度图,28*28表示图像大小
我们可以通过tensorboard,查看当前批量的图片

from torch.utils.tensorboard import SummaryWriterwriter=SummaryWriter("./logs")
steps=0
for data in trainloader:# print(data[0].shape)# breakimages,labels=datawriter.add_images('mnist_images',images,steps)steps+=1
writer.close()

在这里插入图片描述
此处在环境中需要装tensorboard

conda install tensorboard

或者

pip install tensorboard

然后在终端执行

tensorboard --logdir=logs

即可

4.自定义网络模型

这里我们随便选了一张网络结构图,也可以自行设计
在这里插入图片描述
来源:图片来源

4.1卷积操作

这里简单介绍一下卷积的运算方式
在这里插入图片描述
在这里插入图片描述

4.2池化操作

池化层是子采样的一种具体实现方式,这里我们介绍最大池化
在这里插入图片描述

4.3模型搭建

  1. 由图可知,第一层卷积,输入通道数为1(原始图像为灰度图),输出通道数为6,因为图像大小未发生改变,所以padding=2
    在这里插入图片描述
  2. 第二层池化,本文使用的是最大池化
  3. 第三层卷积,输入通道数为6(上一层卷积的输出),输出通道数为16(由图),因为图像大小未发生改变,所以padding=2
  4. 第四层池化,使用最大池化
  5. 全连接层,图示定义了两层
  6. 定义前向传播函数
  7. 输出网络结构
class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()# 第一层卷积层self.conv1=torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)# 第一层池化层self.pool1=torch.nn.MaxPool2d(kernel_size=2,stride=2)# 第二层卷积层self.conv2=torch.nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,padding=2)# 第二层池化层self.pool2=torch.nn.MaxPool2d(kernel_size=2,stride=2)# 全连接层self.fc1=torch.nn.Linear(in_features=16*7*7,out_features=84)self.fc2=torch.nn.Linear(in_features=84,out_features=10)def forward(self,x):x=self.pool1(torch.nn.functional.relu(self.conv1(x)))x=self.pool2(torch.nn.functional.relu(self.conv2(x)))x=x.view(-1,16*7*7)x=torch.nn.functional.relu(self.fc1(x))x=self.fc2(x)return x
net=Net()
net

在这里插入图片描述

5.模型训练

5.1选择损失函数和优化器

对于分类问题,一般选用交叉熵损失CrossEntropyLoss(),优化器此处我们选用随机梯度下降SGD

5.2训练

我们先对训练集进行一轮训练

running_loss=0
for i,data in enumerate(trainloader,1):inputs,labels=dataoutputs=net(inputs)loss=criterion(outputs,labels)running_loss+=lossoptimizer.zero_grad()loss.backward()optimizer.step()if i%100==0:print('epoch:{} loss:{}'.format(i,running_loss/100))running_loss=0

在这里插入图片描述
我们发现loss持续下降,我们训练十轮

running_loss=0
for epoch in range(10):for i,data in enumerate(trainloader,1):inputs,labels=dataoutputs=net(inputs)loss=criterion(outputs,labels)running_loss+=lossoptimizer.zero_grad()loss.backward()optimizer.step()if i%100==0:print('[%d,%5d] loss:%.3f'%(epoch+1,i,running_loss/100))running_loss=0.0
print('Finished Training')

在这里插入图片描述

6.模型的保存

在这里插入图片描述

7.模型的验证

此处我们选择test数据集一个批次验证
在这里插入图片描述
计算整体上的正确率

correct=0
total=0
with torch.no_grad():for data in testloader:images,labels=dataoutputs=net2(images)_,predicted=torch.max(outputs.data,1)total+=labels.size(0)correct+=(predicted==labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

在这里插入图片描述

结语

本篇博客通过训练卷积神经网络CNN模型,实现了对数字的识别,希望对你有所帮助


http://www.ppmy.cn/devtools/161367.html

相关文章

2025保险与金融领域实战全解析:DeepSeek赋能细分领域深度指南(附全流程案例)

🚀 2025保险与金融领域实战全解析:DeepSeek赋能细分领域深度指南(附全流程案例)🚀 📚 目录 DeepSeek在保险与金融中的核心价值保险领域:从风险建模到产品创新金融领域:从投资分析到财富管理区块链与联邦学习的应用探索客户关系与私域运营:全球化体验升级工具与资源…

WordPress Elementor提示错误无法保存500的解决指南

500内部服务器错误是一种常见的服务器错误&#xff0c;通常由网站的服务器环境引起。这种错误可能导致网站无法正常访问&#xff0c;影响用户体验。本文将探讨500错误的常见原因&#xff0c;并提供解决方案&#xff0c;特别针对使用Elementor构建的WordPress网站。 500错误的常…

多源BFS(典型算法思想)—— OJ例题算法解析思路

目录 一、542. 01 矩阵 - 力扣&#xff08;LeetCode&#xff09; 算法代码&#xff1a; 代码逻辑思路 数据结构初始化 步骤一&#xff1a;队列初始化 步骤二&#xff1a;广度优先搜索 返回结果 关键点总结 广度优先搜索&#xff08;BFS&#xff09; 访问标记 复杂度…

如何将MySQL数据库迁移至阿里云

将 MySQL 数据库迁移至阿里云可以通过几种不同的方法&#xff0c;具体选择哪种方式取决于你的数据库大小、数据复杂性以及对迁移速度的需求。阿里云提供了多种迁移工具和服务&#xff0c;本文将为你介绍几种常见的方法。 方法一&#xff1a;使用 阿里云数据库迁移服务 (DTS) 阿…

挪车小程序挪车二维码php+uniapp

一款基于FastAdminThinkPHP开发的匿名通知车主挪车微信小程序&#xff0c;采用匿名通话的方式&#xff0c;用户只能在有效期内拨打车主电话&#xff0c;过期失效&#xff0c;从而保护车主和用户隐私。提供微信小程序端和服务端源码&#xff0c;支持私有化部署。 更新日志 V1.0…

基于RISC-V内核完全自主可控国产化MCU芯片

国科安芯MCU芯片采用开放、灵活的RISC-V指令集架构&#xff0c;RISC-V的开源特性不仅大幅降低研发成本&#xff0c;更赋予芯片设计高度定制化能力。例如&#xff0c;国科安芯的AS32S601抗辐照MCU基于32位RV32IMZicsr指令集&#xff0c;主频达180MHz&#xff0c;内置2MB Flash与…

小波变换背景预测matlab和python样例

小波变换使用matlab和python 注意1d和2d的函数区别。注意默认参数问题。最终三个版本结果能够对齐。 matlab load(wave_in.mat)% res: image of 1536 x 1536 th1; dlevel7; wavenamedb6;[m,n] wavedec2(res, dlevel, wavename);vec zeros(size(m)); vec(1:n(1)*n(1)*1) m…

AI汽车新风向:「死磕」AI底盘,引爆线控底盘新增长拐点

2025开年&#xff0c;DeepSeek火爆出圈&#xff0c;包括吉利、东风汽车、上汽、广汽、长城、长安、比亚迪等车企相继官宣接入&#xff0c;掀起了“AI定义汽车”浪潮。 而这股最火的AI汽车热潮&#xff0c;除了深度赋能智能座舱、智能驾驶等AI竞争更白热化的细分场景&#xff0…