pytorch入门级项目--基于卷积神经网络的数字识别

embedded/2025/2/24 17:17:19/

文章目录

  • 前言
  • 1.数据集的介绍
  • 2.数据集的准备
  • 3.数据集的加载
  • 4.自定义网络模型
    • 4.1卷积操作
    • 4.2池化操作
    • 4.3模型搭建
  • 5.模型训练
    • 5.1选择损失函数和优化器
    • 5.2训练
  • 6.模型的保存
  • 7.模型的验证
  • 结语

前言

本篇博客主要针对pytorch入门级的教程,实现了一个基于卷积神经网络(CNN)架构的数字识别,带你了解由数据集到模型验证的全过程。

1.数据集的介绍

MNIST数据集是一个广泛用于机器学习和计算机视觉领域的手写数字图像数据集。它是深度学习入门的经典数据集,常用于图像识别任务,特别是手写数字识别。该数据集分为训练集和测试集:

  • 训练集:包含60,000张手写数字图像,用于模型的训练。
  • 测试集:包含10,000张手写数字图像,用于模型的评估。

2.数据集的准备

Torchvision在torchvision. datasets模块中提供了许多内置数据集,其中便包含了MNIST数据集,因此可以直接通过torchvision. datasets直接下载。

  • root代表存储路径,此处采用的是相对路径,保存在当前文件所在文件夹的data文件夹下
  • train用来区分训练集和测试集
  • download用来表示是否下载数据集

同时我们可以查看其中一条数据,看看该数据及具体形式

import torchvision
trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True)
testset=torchvision.datasets.MNIST(root='./data',train=False,download=True)
trainset[0]

在这里插入图片描述
这里我的数据集已经下载好了,所以很快执行好了
(<PIL.Image.Image image mode=L size=28x28>, 5)通过输出我们可以发现,第一条数据是一个元组,第一个元素表示data(灰度图,大小28*28),第二个元素表示label,指的是该图片的类别,下面我们可以查看每个类别的含义
在这里插入图片描述
因此我们可以看到第一张图片表示的就是手写数字5,也可以通过trainset[0][0].show()进行显示
在这里插入图片描述
在进行后续操作前,需要将图片格式由PIL调整为tensor类型。

import torchvision
trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=torchvision.transforms.ToTensor())
testset=torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=torchvision.transforms.ToTensor())
trainset[0][0].shape

在这里插入图片描述

至此,数据集准备阶段就算完成了。

3.数据集的加载

在训练深度学习模型时,通常不会一次性将整个数据集输入到模型中,而是将数据集分成多个小批量(mini-batches)进行训练。

import torchtrainloader=torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader=torch.utils.data.DataLoader(testset,batch_size=64,shuffle=False)
type(trainloader),len(trainloader),len(trainset)
  • batch_size:表示批量大小
  • shuffle:表示是否打乱,一般训练集打乱,测试集打乱不打乱都可以
  • num_workers:表示多个进程,windows系统可能报错
    这里提供解决windows系统中多进程报错的可能解决方案:
    加上if __name__ == '__main__':即可,如果你是用的是jupyter,不用此操作

在这里插入图片描述
此时,我们取出一条数据查看类型:
在这里插入图片描述
此时我们可以看到,一条数据为torch.Size([64, 1, 28, 28]),64表示批量大小,1表示该图像单通道,即为灰度图,28*28表示图像大小
我们可以通过tensorboard,查看当前批量的图片

from torch.utils.tensorboard import SummaryWriterwriter=SummaryWriter("./logs")
steps=0
for data in trainloader:# print(data[0].shape)# breakimages,labels=datawriter.add_images('mnist_images',images,steps)steps+=1
writer.close()

在这里插入图片描述
此处在环境中需要装tensorboard

conda install tensorboard

或者

pip install tensorboard

然后在终端执行

tensorboard --logdir=logs

即可

4.自定义网络模型

这里我们随便选了一张网络结构图,也可以自行设计
在这里插入图片描述
来源:图片来源

4.1卷积操作

这里简单介绍一下卷积的运算方式
在这里插入图片描述
在这里插入图片描述

4.2池化操作

池化层是子采样的一种具体实现方式,这里我们介绍最大池化
在这里插入图片描述

4.3模型搭建

  1. 由图可知,第一层卷积,输入通道数为1(原始图像为灰度图),输出通道数为6,因为图像大小未发生改变,所以padding=2
    在这里插入图片描述
  2. 第二层池化,本文使用的是最大池化
  3. 第三层卷积,输入通道数为6(上一层卷积的输出),输出通道数为16(由图),因为图像大小未发生改变,所以padding=2
  4. 第四层池化,使用最大池化
  5. 全连接层,图示定义了两层
  6. 定义前向传播函数
  7. 输出网络结构
class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()# 第一层卷积层self.conv1=torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)# 第一层池化层self.pool1=torch.nn.MaxPool2d(kernel_size=2,stride=2)# 第二层卷积层self.conv2=torch.nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,padding=2)# 第二层池化层self.pool2=torch.nn.MaxPool2d(kernel_size=2,stride=2)# 全连接层self.fc1=torch.nn.Linear(in_features=16*7*7,out_features=84)self.fc2=torch.nn.Linear(in_features=84,out_features=10)def forward(self,x):x=self.pool1(torch.nn.functional.relu(self.conv1(x)))x=self.pool2(torch.nn.functional.relu(self.conv2(x)))x=x.view(-1,16*7*7)x=torch.nn.functional.relu(self.fc1(x))x=self.fc2(x)return x
net=Net()
net

在这里插入图片描述

5.模型训练

5.1选择损失函数和优化器

对于分类问题,一般选用交叉熵损失CrossEntropyLoss(),优化器此处我们选用随机梯度下降SGD

5.2训练

我们先对训练集进行一轮训练

running_loss=0
for i,data in enumerate(trainloader,1):inputs,labels=dataoutputs=net(inputs)loss=criterion(outputs,labels)running_loss+=lossoptimizer.zero_grad()loss.backward()optimizer.step()if i%100==0:print('epoch:{} loss:{}'.format(i,running_loss/100))running_loss=0

在这里插入图片描述
我们发现loss持续下降,我们训练十轮

running_loss=0
for epoch in range(10):for i,data in enumerate(trainloader,1):inputs,labels=dataoutputs=net(inputs)loss=criterion(outputs,labels)running_loss+=lossoptimizer.zero_grad()loss.backward()optimizer.step()if i%100==0:print('[%d,%5d] loss:%.3f'%(epoch+1,i,running_loss/100))running_loss=0.0
print('Finished Training')

在这里插入图片描述

6.模型的保存

在这里插入图片描述

7.模型的验证

此处我们选择test数据集一个批次验证
在这里插入图片描述
计算整体上的正确率

correct=0
total=0
with torch.no_grad():for data in testloader:images,labels=dataoutputs=net2(images)_,predicted=torch.max(outputs.data,1)total+=labels.size(0)correct+=(predicted==labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

在这里插入图片描述

结语

本篇博客通过训练卷积神经网络CNN模型,实现了对数字的识别,希望对你有所帮助


http://www.ppmy.cn/embedded/164867.html

相关文章

HTML 字符实体

HTML 字符实体 概述 HTML字符实体是一种用于在HTML文档中表示特殊字符的方法。在HTML中&#xff0c;一些字符&#xff08;如<、>、&等&#xff09;具有特殊意义&#xff0c;不能直接使用在文本内容中。为了解决这个问题&#xff0c;HTML提供了一套字符实体来替代这…

遗传算法初探

组成要素 编码 分为二进制编码、实数编码和顺序编码 初始种群的产生 分为随机方法、基于反向学习优化的种群产生。 基于反向学习优化的种群其思想是先随机生成一个种群P(N)&#xff0c;然后按照反向学习方法生成新的种群OP(N),合并两个种群&#xff0c;得到一个新的种群S(N…

洛谷P9241 [蓝桥杯 2023 省 B] 飞机降落

题目描述 N 架飞机准备降落到某个只有一条跑道的机场。其中第 i 架飞机在 Ti​ 时刻到达机场上空&#xff0c;到达时它的剩余油料还可以继续盘旋 Di​ 个单位时间&#xff0c;即它最早可以于 Ti​ 时刻开始降落&#xff0c;最晩可以于 Ti​Di​ 时刻开始降落。降落过程需要 Li…

【开源项目】分布式文本多语言翻译存储平台

分布式文本多语言翻译存储平台 地址&#xff1a; Gitee&#xff1a;https://gitee.com/dreamPointer/zza-translation/blob/master/README.md 一、提供服务 分布式文本翻译服务&#xff0c;长文本翻译支持流式回调&#xff08;todo&#xff09;分布式文本多语言翻译结果存储服…

MacOS下使用Ollama本地构建DeepSeek并使用本地Dify构建AI应用

目录 1 大白话说一下文章内容2 作者的电脑配置3 DeepSeek的本地部署3.1 Ollamal的下载和安装3.2 选择合适的deepseek模型3.3 安转deepseek 4 DifyDeepSeek构建Al应用4.1 Dify的安装4.1.1 前置条件4.1.2 拉取代码4.1.3 启动Dify 4.2 Dify控制页面4.3 使用Dify实现个“文章标题生…

基于springboot+vue的考研互助平台

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

hydra docker版本

最近做ssh暴力破解实验&#xff0c;由于服务器上面软件依赖太乱了&#xff0c;导致我花了好久没能成功编译出hydra&#xff0c;于是想到了使用docker版本的hydra&#xff0c;最后成功的完成了ssh暴力破解实验&#xff5e; ailx10 1958 次咨询 网络安全优秀回答者 互联网行业…

从0开始的操作系统手搓教程 6 ——检查我们的内存信息

目录 如何检测内存 使用0xE820办法获取所有的内存 使用E801办法获取我们的内存 大保底&#xff1a;使用0x88功能 开始实现 修正我们的跳转偏移 开始内存检查 下一篇 现在&#xff0c;我们可以进一步迈向我们的内核了。那就是加载我们的内核。刚刚我们说过&#xff0c;加…