pytorch入门级项目--基于卷积神经网络的数字识别

news/2025/2/24 18:43:19/

文章目录

  • 前言
  • 1.数据集的介绍
  • 2.数据集的准备
  • 3.数据集的加载
  • 4.自定义网络模型
    • 4.1卷积操作
    • 4.2池化操作
    • 4.3模型搭建
  • 5.模型训练
    • 5.1选择损失函数和优化器
    • 5.2训练
  • 6.模型的保存
  • 7.模型的验证
  • 结语

前言

本篇博客主要针对pytorch入门级的教程,实现了一个基于卷积神经网络(CNN)架构的数字识别,带你了解由数据集到模型验证的全过程。

1.数据集的介绍

MNIST数据集是一个广泛用于机器学习和计算机视觉领域的手写数字图像数据集。它是深度学习入门的经典数据集,常用于图像识别任务,特别是手写数字识别。该数据集分为训练集和测试集:

  • 训练集:包含60,000张手写数字图像,用于模型的训练。
  • 测试集:包含10,000张手写数字图像,用于模型的评估。

2.数据集的准备

Torchvision在torchvision. datasets模块中提供了许多内置数据集,其中便包含了MNIST数据集,因此可以直接通过torchvision. datasets直接下载。

  • root代表存储路径,此处采用的是相对路径,保存在当前文件所在文件夹的data文件夹下
  • train用来区分训练集和测试集
  • download用来表示是否下载数据集

同时我们可以查看其中一条数据,看看该数据及具体形式

import torchvision
trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True)
testset=torchvision.datasets.MNIST(root='./data',train=False,download=True)
trainset[0]

在这里插入图片描述
这里我的数据集已经下载好了,所以很快执行好了
(<PIL.Image.Image image mode=L size=28x28>, 5)通过输出我们可以发现,第一条数据是一个元组,第一个元素表示data(灰度图,大小28*28),第二个元素表示label,指的是该图片的类别,下面我们可以查看每个类别的含义
在这里插入图片描述
因此我们可以看到第一张图片表示的就是手写数字5,也可以通过trainset[0][0].show()进行显示
在这里插入图片描述
在进行后续操作前,需要将图片格式由PIL调整为tensor类型。

import torchvision
trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=torchvision.transforms.ToTensor())
testset=torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=torchvision.transforms.ToTensor())
trainset[0][0].shape

在这里插入图片描述

至此,数据集准备阶段就算完成了。

3.数据集的加载

在训练深度学习模型时,通常不会一次性将整个数据集输入到模型中,而是将数据集分成多个小批量(mini-batches)进行训练。

import torchtrainloader=torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader=torch.utils.data.DataLoader(testset,batch_size=64,shuffle=False)
type(trainloader),len(trainloader),len(trainset)
  • batch_size:表示批量大小
  • shuffle:表示是否打乱,一般训练集打乱,测试集打乱不打乱都可以
  • num_workers:表示多个进程,windows系统可能报错
    这里提供解决windows系统中多进程报错的可能解决方案:
    加上if __name__ == '__main__':即可,如果你是用的是jupyter,不用此操作

在这里插入图片描述
此时,我们取出一条数据查看类型:
在这里插入图片描述
此时我们可以看到,一条数据为torch.Size([64, 1, 28, 28]),64表示批量大小,1表示该图像单通道,即为灰度图,28*28表示图像大小
我们可以通过tensorboard,查看当前批量的图片

from torch.utils.tensorboard import SummaryWriterwriter=SummaryWriter("./logs")
steps=0
for data in trainloader:# print(data[0].shape)# breakimages,labels=datawriter.add_images('mnist_images',images,steps)steps+=1
writer.close()

在这里插入图片描述
此处在环境中需要装tensorboard

conda install tensorboard

或者

pip install tensorboard

然后在终端执行

tensorboard --logdir=logs

即可

4.自定义网络模型

这里我们随便选了一张网络结构图,也可以自行设计
在这里插入图片描述
来源:图片来源

4.1卷积操作

这里简单介绍一下卷积的运算方式
在这里插入图片描述
在这里插入图片描述

4.2池化操作

池化层是子采样的一种具体实现方式,这里我们介绍最大池化
在这里插入图片描述

4.3模型搭建

  1. 由图可知,第一层卷积,输入通道数为1(原始图像为灰度图),输出通道数为6,因为图像大小未发生改变,所以padding=2
    在这里插入图片描述
  2. 第二层池化,本文使用的是最大池化
  3. 第三层卷积,输入通道数为6(上一层卷积的输出),输出通道数为16(由图),因为图像大小未发生改变,所以padding=2
  4. 第四层池化,使用最大池化
  5. 全连接层,图示定义了两层
  6. 定义前向传播函数
  7. 输出网络结构
class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()# 第一层卷积层self.conv1=torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)# 第一层池化层self.pool1=torch.nn.MaxPool2d(kernel_size=2,stride=2)# 第二层卷积层self.conv2=torch.nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,padding=2)# 第二层池化层self.pool2=torch.nn.MaxPool2d(kernel_size=2,stride=2)# 全连接层self.fc1=torch.nn.Linear(in_features=16*7*7,out_features=84)self.fc2=torch.nn.Linear(in_features=84,out_features=10)def forward(self,x):x=self.pool1(torch.nn.functional.relu(self.conv1(x)))x=self.pool2(torch.nn.functional.relu(self.conv2(x)))x=x.view(-1,16*7*7)x=torch.nn.functional.relu(self.fc1(x))x=self.fc2(x)return x
net=Net()
net

在这里插入图片描述

5.模型训练

5.1选择损失函数和优化器

对于分类问题,一般选用交叉熵损失CrossEntropyLoss(),优化器此处我们选用随机梯度下降SGD

5.2训练

我们先对训练集进行一轮训练

running_loss=0
for i,data in enumerate(trainloader,1):inputs,labels=dataoutputs=net(inputs)loss=criterion(outputs,labels)running_loss+=lossoptimizer.zero_grad()loss.backward()optimizer.step()if i%100==0:print('epoch:{} loss:{}'.format(i,running_loss/100))running_loss=0

在这里插入图片描述
我们发现loss持续下降,我们训练十轮

running_loss=0
for epoch in range(10):for i,data in enumerate(trainloader,1):inputs,labels=dataoutputs=net(inputs)loss=criterion(outputs,labels)running_loss+=lossoptimizer.zero_grad()loss.backward()optimizer.step()if i%100==0:print('[%d,%5d] loss:%.3f'%(epoch+1,i,running_loss/100))running_loss=0.0
print('Finished Training')

在这里插入图片描述

6.模型的保存

在这里插入图片描述

7.模型的验证

此处我们选择test数据集一个批次验证
在这里插入图片描述
计算整体上的正确率

correct=0
total=0
with torch.no_grad():for data in testloader:images,labels=dataoutputs=net2(images)_,predicted=torch.max(outputs.data,1)total+=labels.size(0)correct+=(predicted==labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

在这里插入图片描述

结语

本篇博客通过训练卷积神经网络CNN模型,实现了对数字的识别,希望对你有所帮助


http://www.ppmy.cn/news/1574683.html

相关文章

网络安全:防范NetBIOS漏洞的攻击

稍微懂点电脑知识的朋友都知道&#xff0c;NetBIOS 是计算机局域网领域流行的一种传输方式&#xff0c;但你是否还知道&#xff0c;对于连接互联网的机器来讲&#xff0c;NetBIOS是一大隐患。 漏洞描述 NetBIOS(Network Basic Input Output System&#xff0c;网络基本输入输…

【SpringBoot教程】SpringBoot整合Caffeine本地缓存及Spring Cache注解的使用

&#x1f64b;大家好&#xff01;我是毛毛张! &#x1f308;个人首页&#xff1a; 神马都会亿点点的毛毛张 毛毛张今天要介绍的是本地缓存之王&#xff01;Caffeine&#xff01;SpringBoot整合Caffeine本地缓存及Spring Cache注解的使用 文章目录 1.Caffeine本地缓存1.1 本地…

【学习笔记】【SpringCloud】MybatisPlus 基础使用

目录 一、使用 MybatisPlus 基本步骤 1. 引入 MybatisPlus 依赖 2. 定义Mapper接口并继承BaseMapper 二、MybatisPlus 常用配置 三、自定义SQL 四、IService 接口 1. 批量新增的效率问题 2. 配置方式 五、插件功能 1. 分页插件 一、使用 MybatisPlus 基本步骤 1. 引…

OSPF基础知识总结

基本概念 协议类型:链路状态型IGP(内部网关协议),基于Dijkstra算法计算最短路径树。 协议号:IP层协议,协议号89。 特点:支持分层设计(区域划分)、快速收敛、无环路、支持VLSM/CIDR。 区域(Area) 骨干区域(Backbone Area):Area 0,所有非骨干区域必须直接或通过虚…

第四届图像、信号处理与模式识别国际学术会议(ISPP 2025)

重要信息 大会官网&#xff1a;www.icispp.com 大会时间&#xff1a;2025年3月28日-30日 大会地点&#xff1a;南京 简介 由河海大学和江苏大学联合主办的第四届图像、信号处理与模式识别&#xff08;ISPP 2025) 将于2025年3月28日-30日在中国南京举行。主要围绕图像信号处…

C#项目04——递归求和

实现逻辑 利用递归&#xff0c;求取1~N以内的和 知识点 正常情况下&#xff0c;C#每条线程都会分配1MB的地址空间&#xff0c;因此执行递归的层次不能太深&#xff0c;否则就会出现溢出的风险&#xff0c; 业务设计 程序代码 private void button1_Click(object sender, E…

Mac OS JAVA_HOME设置

个人博客地址&#xff1a;Mac OS JAVA_HOME设置 | 一张假钞的真实世界 在MacOS上使用DMG文件安装了Jdk8 之后&#xff0c;在默认路径下找不到JDK的HOME路径&#xff1a; $ which java /usr/bin/java $ ls -l /usr/bin/java lrwxr-xr-x 1 root wheel 74 12 6 2015 /usr/b…

力扣hot100 ——搜索二维矩阵 || m+n复杂度优化解法

编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性&#xff1a; 每行的元素从左到右升序排列。每列的元素从上到下升序排列。 解题思路&#xff1a; 借助行和列有序特性&#xff0c;不断按行或者列缩小范围&#xff1b;途中数字表示每…