手写数字识别案例分析(torch,深度学习入门)

embedded/2024/9/24 8:47:59/

人工智能机器学习的广阔领域中,手写数字识别是一个经典的入门级问题,它不仅能够帮助我们理解深度学习的基本原理,还能作为实践编程和模型训练的良好起点。本文将带您踏上手写数字识别的深度学习之旅,从数据集介绍、模型构建到训练与评估,一步步深入探索。

一、引言

手写数字识别(Handwritten Digit Recognition)是指通过计算机程序自动识别手写数字的过程。最著名的手写数字数据集之一是MNIST(Modified National Institute of Standards and Technology database),它包含了大量的手写数字图片,每张图片都被标记了对应的数字(0-9)。这个数据集成为了初学者学习深度学习,尤其是卷积神经网络(CNN)的首选。

二、MNIST数据集简介

MNIST数据集由60,000个训练样本和10,000个测试样本组成,每个样本都是一张28x28像素的灰度图像,代表了一个手写数字。这些图像已经被归一化并居中在图像中心,使得数字不会受到位置变化的影响。

 PyTorch 和 torchvision 库来下载并准备 MNIST 数据集,包括训练集和测试集

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor'''下载训练数据集(图片+标签)'''
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor()
)
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)
  1. 打印设备信息:您的代码已经很好地检查了CUDA和MPS(针对Apple M系列芯片)的可用性,并设置了相应的设备。但是,在打印设备信息时,有一个小错误在字符串格式化上。您需要确保在字符串中正确地包含变量名。

  2. 打印数据形状:您已经正确地设置了DataLoader并打印了测试数据集中的一个批次的数据和标签的形状。这是一个很好的实践,可以帮助您了解数据的维度。

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 通常训练时会打乱数据  
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)  # 测试时不需要打乱数据  # 打印测试数据集的一个批次的数据和标签的形状  
for x, y in test_dataloader:  print(f"Shape of x [N,C,H,W]: {x.shape}")  # 注意这里的x是图像,但MNIST是灰度图,所以C=1  print(f"Shape of y: {y.shape}, {y.dtype}")  # y是标签,通常是一维的,且为long类型  break  # 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU  
device = "cuda" if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else "cpu")  
print(f"Using {device} device")  # 确保在字符串中正确地包含了变量名  

三、训练模型选择

一、创建一个具有多个隐藏层的神经网络,这些层都使用了nn.Linear来定义全连接层,并使用torch.sigmoid作为激活函数。

import torch  
import torch.nn as nn  class NeuralNetwork(nn.Module):  def __init__(self):  super().__init__()  self.flatten = nn.Flatten()  self.hidden1 = nn.Linear(28 * 28, 256)  self.relu1 = nn.ReLU()  self.hidden2 = nn.Linear(256, 128)  self.relu2 = nn.ReLU()  self.hidden3 = nn.Linear(128, 64)  self.relu3 = nn.ReLU()  self.hidden4 = nn.Linear(64, 32)  self.relu4 = nn.ReLU()  self.out = nn.Linear(32, 10)  # 输出层对应于10个类别的得分  def forward(self, x):x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x)x = self.hidden3(x)x = torch.sigmoid(x)x = self.hidden4(x)x = torch.sigmoid(x)x = self.out(x)return x model = NeuralNetwork().to(device)  
print(model)  

二、定义了一个具有三个卷积层的CNN,每个卷积层后面都跟着ReLU激活函数,前两个卷积层后面还跟着最大池化层。最后,通过一个全连接层将卷积层的输出转换为10个类别的得分。

import torch  
import torch.nn as nn  class CNN(nn.Module):  def __init__(self):  super(CNN, self).__init__()  self.conv1 = nn.Sequential(  nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),  nn.ReLU(),  nn.MaxPool2d(kernel_size=2),  )  self.conv2 = nn.Sequential(  nn.Conv2d(16, 32, 5, 1, 2),  nn.ReLU(),  nn.Conv2d(32, 32, 5, 1, 2),  nn.ReLU(),  nn.MaxPool2d(2),  )  self.conv3 = nn.Sequential(  nn.Conv2d(32, 64, 5, 1, 2),  nn.ReLU(),  )  self.out = nn.Linear(64 * 7 * 7, 10)  # 确保这里的输入特征数与卷积层输出后的特征数相匹配  def forward(self, x):  x = self.conv1(x)  x = self.conv2(x)  x = self.conv3(x)  # 输出应为(batch_size, 64, 7, 7)  x = x.view(x.size(0), -1)  # 展平操作,输出为(batch_size, 64*7*7)  output = self.out(x)  return output  model = CNN().to(device)  
print(model)
  • in_channels=1:这指定了输入图像的通道数。

  • out_channels=16:这指定了卷积操作后输出的通道数,也就是卷积核(或称为滤波器)的数量。

  • kernel_size=5:这定义了卷积核的大小。

  • stride=1:这指定了卷积核在输入数据上滑动的步长。

  • padding=2:这定义了要在输入数据周围添加的零填充(zero-padding)的数量。

四、处理数据集和测试集

训练集处理:

def train(dataloader, model, loss_fn, optimizer):  model.train()  # 将模型设置为训练模式  batch_size_num = 1  # 这不是标准的用法,但在这里用作计数已处理批次的数量  for x, y in dataloader:  # 遍历数据加载器中的每个批次  x, y = x.to(device), y.to(device)  # 将数据和标签移动到指定的设备(如GPU)  pred = model(x)  # 通过模型进行前向传播  loss = loss_fn(pred, y)  # 计算预测和真实标签之间的损失  optimizer.zero_grad()  # 清除之前的梯度  loss.backward()  # 反向传播,计算当前梯度  optimizer.step()  # 更新模型的权重  loss_value = loss.item()if batch_size_num % 200 == 0:print(f"{loss_value:>7f}[number:{batch_size_num}]")#打印结果batch_size_num += 1  # 增加已处理批次的数量

测试集处理:

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model(x)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizeprint(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')

模型训练:

loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)epochs = 10
for t in range(epochs):print(f"-----------------------------------------------\nepcho{t+1}")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model, loss_fn)

结果:

神经网络:

cnn:


http://www.ppmy.cn/embedded/116007.html

相关文章

常用的k8s容器网络模式有哪些?

常用的k8s容器网络模式包括Bridge模式、Host模式、Overlay模式、Flannel模式、CNI(ContainerNetworkInterface)模式。K8s的容器网络模式多种多样,每种模式都有其特点和适用场景。Bridge模式适用于简单的容器通信场景;Host模式适用…

手机上轻松解压并处理 JSON 文件

JSON(JavaScript Object Notation)是一种轻量级的数据交换格式,在手机上有着广泛的应用场景。 首先,在数据传输方面,许多移动应用程序通过网络请求与后端服务器进行交互,而服务器端的 API 接口通常使用 JS…

Linux-网络编程

1. 初始网络协议 “协议” 是一种约定. 打电话约定电话铃响的次数的约定 协议分层 协议本质也是软件, 在设计上为了更好的进行模块化, 解耦合, 也是被设计成为层状结构的 1.1 OSI 七层模型 OSI(Open System Interconnection&#…

【电商搜索】现代工业级电商搜索技术-Ha3搜索引擎平台简介

【电商搜索】现代工业级电商搜索技术-Ha3搜索引擎平台简介 — 初稿V1.0 Ha3搜索引擎平台详细介绍 在当今的互联网时代,搜索引擎扮演着至关重要的角色,尤其是在电子商务领域。Ha3搜索引擎平台是由阿里巴巴搜索团队开发的一个先进的搜索引擎&#xff0c…

scss知识汇总

参考资料 https://www.bilibili.com/video/BV1KJ411Y7Zz?p11 //入门 https://www.bilibili.com/video/BV1bK411H7YU?fromsearch&seid1507236772512004325 //精简 https://www.bilibili.com/video/BV1KE411b7RQ?p25 //大全h…

【bug】通过lora方式微调sdxl inpainting踩坑

报错内容 ValueError: Attempting to unscale FP16 gradients. 报错位置 if accelerator.sync_gradients:params_to_clip (itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)if args.train_text_encoderelse unet_lora_parameters…

热斑黄斑光伏发电板 红外黄斑检测图像数据集内含最高温度信息 1200张,jpg格式。

热斑黄斑光伏发电板 红外黄斑检测图像数据集 内含最高温度信息 1200张,jpg格式。 热斑黄斑光伏发电板红外黄斑检测图像数据集介绍 数据集名称 热斑黄斑光伏发电板红外黄斑检测图像数据集(Hot Spot and Yellow Spot Detection in Photovoltaic Panels I…

golang调用163邮箱发送邮件

一、导入依赖 go get gopkg.in/gomail.v2 go get github.com/spf13/viper二、发送邮件的方法 注:所有配置均写在了配置文件当中,此处用viper调用 // 定义发送邮件的功能方法 func sendMail(SendFileName string) error {// 此处是邮件的正文message : …