基于MNIST的手写数字识别

server/2025/2/13 0:04:59/

上次我们基于CIFAR-10训练一个图像分类器,梳理了一下训练模型的全过程,并且对卷积神经网络有了一定的理解,我们再在GPU上搭建一个手写的数字识别cnn网络,加深巩固一下

步骤

  1. 加载数据集
  2. 定义神经网络
  3. 定义损失函数
  4. 训练网络
  5. 测试网络

MNIST数据集简介

MINIST是一个手写数字数据库(官网地址:http://yann.lecun.com/exdb/mnist/),它有6w张训练样本和1w张测试样本,每张图的像素尺寸为28*28,如下图一共4个图片,这些图片文件均被保存为二进制格式

训练全过程

1.加载数据集

import torch
import torchvision
from torchvision import transforms
trainset = torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))
trainloader = torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

展示一些训练图片

import numpy as np
import matplotlib.pyplot as plt
def imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()
# 得到batch中的数据
dataiter = iter(train_loader)
images, labels = dataiter.next()imshow(torchvision.utils.make_grid(images))

2.定义卷积神经网络

import torch
import torch.nn as nn
import torch.nn.functional as F#可以调用一些常见的函数,例如非线性以及池化等
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# input image channel, 6 output channels, 5x5 square convolutionself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# 全连接 从16 * 4 * 4的维度转成120self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)#(2,2)也可以直接写成数字2x = x.view(-1, self.num_flat_features(x))#将维度转成以batch为第一维 剩余维数相乘为第二维x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  # 第一个维度batch不考虑num_features = 1for s in size:num_features *= sreturn num_features
net = Net()
print(net)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net.to(device)

3.定义损失和优化器

criterion = nn.CrossEntropyLoss()
import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

这里设置了 momentum=0.9 ,训练一轮的准确率由90%提到了98%

4.训练网络

def train(epochs):net.train()for epoch in range(epochs):running_loss = 0.0for i, data in enumerate(trainloader):# 得到输入 和 标签inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 消除梯度optimizer.zero_grad()# 前向传播 计算损失 后向传播 更新参数outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印日志running_loss += loss.item()if i % 100 == 0:    # 每100个batch打印一次print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 100))running_loss = 0.0
torch.save(net, 'mnist.pth')

net.train():调用方法时,模型将进入训练模式。在训练模式下,一些特定的模块,例如Dropout和Batch Normalization,将被启用。这是因为在训练过程中,我们需要使用Dropout来防止过拟合,并使用Batch Normalization来加速收敛

net.eval():调用方法时,模型将进入评估模式。在评估模式下,一些特定的模块,例如Dropout和Batch Normalization,将被禁用。这是因为在评估过程中,我们不需要使用Dropout来防止过拟合,并且Batch Normalization的统计信息应该是固定的。

5.测试网络

在其它地方导入模型测试时需要将类的定义添加到加载模型的这个py文件中

from mnist.py import Net  # 导入会运行mnist.py
net = torch.load('mnist.pth')testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
]))
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)correct = 0
total = 0
net.to('cpu') 
print(net)with torch.no_grad():  # 或者model.eval()for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

训练一轮速度

GPU:10s

CPU:10s

训练三轮速度

GPU:24.5s

CPU:28.6s

得出结论:训练数据计算量少的时候,无论在CPU上还是GPU,性能几乎都是接近的,而当训练数据计算量达到一定多的时候,GPU的优势就比较显著直观了

小小实验:

(1)加载并测试一张图片,正确则输出True

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import cv2
import numpy as npclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)  x = x.view(-1, self.num_flat_features(x))  x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  num_features = 1for s in size:num_features *= sreturn num_featurescorrect = 0
total = 0
net = torch.load('mnist.pth')
net.to('cpu')
# print(net)with torch.no_grad(): imgdir = '3.jpeg'img = cv2.imread(imgdir, 0)img = cv2.resize(img, (28, 28))trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])image = trans(img)image = image.unsqueeze(0)label = torch.tensor([int(imgdir.split('.')[0])])outputs = net(image)_, predicted = torch.max(outputs.data, 1)print(predicted)print((predicted == label).item())

拿刚刚训练的模型试了6张数字图片,只有一张2是预测对的....

unsuqeeze:通过unsuqeeze(int)中的int整数,增加一个维度,int整数表示维度增加到哪儿去,且维度为1,参数:【0, 1, 2】


http://www.ppmy.cn/server/13334.html

相关文章

(2022级)成都工业学院数据库原理及应用实验六: SQL DML(增、删、改)

写在前面 1、基于2022级软件工程/计算机科学与技术实验指导书 2、成品仅提供参考 3、如果成品不满足你的要求,请寻求其他的途径 运行环境 window11家庭版 Navicat Premium 16 Mysql 8.0.36 实验要求 在实验三的基础上完成下列查询: 1、在科室表…

Redis篇:缓存更新策略最佳实践

前景: 缓存更新是redis为了节约内存而设计出来的一个东西,主要是因为内存数据宝贵,当我们向redis插入太多数据,此时就可能会导致缓存中的数据过多,所以redis会对部分数据进行更新,或者把他叫为淘汰更合适&a…

多商家AI智能名片商城系统(开源版)——构建高效数字化商业新生态

一、项目概述 1、项目背景 1)起源 随着数字化时代的快速发展,传统名片和商城系统已经难以满足企业日益增长的需求。商家需要更高效、更智能的方式来展示自己的产品和服务,与消费者进行互动和交易。同时,开源技术的普及也为开发…

Spring MVC和Spring Boot

上节已经提到过请求,这次梳理响应。 响应 响应基本上都要被Controller所托管,告诉Spring帮我们管理这个代码,我们在后面需要访问时,才可以进行访问,否则将会报错。并且其是由RestController分离出来的,Re…

免杀技术之白加黑的攻击防御

一、介绍 1. 什么是白加黑 通俗的讲白加黑中的白就是指被杀软列入到可信任列表中的文件。比如说微软自带的系统文件或者一些有有效证书签名的文件,什么是微软文件,或者什么是有效签名文件在后面我们会提到他的辨别方法。黑就是指我们自己的文件,没有有…

OpenHarmony多媒体-GSYVideoPlayer

简介 GSYVideoPlayer是一个视频播放器库,支持切换内核播放器(IJKPlayer、avplayer),并且支持了多种能力。 效果展示: 下载安装 ohpm install ohos/gsyvideoplayerOpenHarmony ohpm 环境配置等更多内容,请…

PSA Group EDI 需求分析

PSA集团(以下简称PSA)中文名为标致雪铁龙集团,是一家法国私营汽车制造公司,致力于为全球消费者提供独具特色的汽车体验和自由愉悦的出行方案,旗下拥有标致、雪铁龙、DS、欧宝、沃克斯豪尔五大汽车品牌。 汽车制造企业对…

【JavaScript】axios

基础使用 <script src"https://cdn.bootcdn.net/ajax/libs/axios/1.5.0/axios.min.js"></script> <script>axios.get(https://study.duyiedu.com/api/herolist).then(res> {console.log(res.data)}) </script>get - params <script s…