【基于深度学习的验证码识别】---- part3数据加载、模型等API介绍(1)

embedded/2025/3/20 18:48:23/

一、MNIST数据集

MNIST(Modified National Institute of Standards and Technology)数据集是计算机视觉和机器学习领域最经典的入门级数据集之一,主要用于手写数字识别任务。

使用示例(以PyTorch为例)

from torchvision.datasets import MNIST
mnist_train = MNIST(root='./MNIST_data', train=True, download=True)

在这里插入图片描述

from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
mnist_train = MNIST(root='./MNIST_data', train=True, download=True)# 训练集长度
print(len(mnist_train))
# 取第一个图片
print(mnist_train[0])
image = mnist_train[5000][0]
# 打印出图片
plt.imshow(image)
plt.show()
print(mnist_train[5000][1])

二、数据加载

在PyTorch中,使用DataLoader加载MNIST数据集时,参数的合理配置直接影响训练效率和模型性能。以下是核心参数的详细说明及其在MNIST场景中的应用:

from torch.utils.data import DataLoader

参数:batch_size、shuffle、num_workers、pin_memory、drop_last

1、batch_size(批次大小)
  • 定义:每个批次包含的样本数量。例如,batch_size=64表示每次迭代加载64张图像。
  • 作用:定义每个批次包含的样本数量。例如,若batch_size=64,则每次迭代从数据集中加载64张手写数字图像。
  • MNIST应用
    MNIST图像尺寸为28x28,单个样本数据量小,通常可设置较大的batch_size(如64或128)以充分利用显存并加速训练。
    显存不足时需减小batch_size,否则会引发内存错误(OOM)
2、 shuffle(数据打乱)
  • 定义:是否在每个训练周期(epoch)开始时随机打乱数据顺序。
  • 作用
    • 防止模型偏见:避免模型学习到数据顺序特征(如MNIST训练集需设为True)。
    • 测试集处理:测试时通常设为False以保持评估结果一致性。
  • MNIST应用
    # 训练集打乱,测试集不打乱
    train_loader = DataLoader(..., shuffle=True)
    test_loader = DataLoader(..., shuffle=False)
    
3、 num_workers(子进程数)
  • 定义:用于并行加载数据的子进程数量。默认为0(主进程加载)。
  • 作用
    • 加速数据加载:多进程并行读取数据(建议设为CPU核心数的2~4倍,如4或8)。
    • 资源平衡:MNIST数据量小,过高值可能导致内存溢出(需实验调优)。
  • MNIST应用
    # 使用4个子进程加载数据
    train_loader = DataLoader(..., num_workers=4)
    
4、pin_memory(内存锁定)
  • 定义:是否将数据复制到CUDA固定内存(pinned memory)。
  • 作用
    • 加速GPU传输:启用后,数据从CPU到GPU的传输速度更快(GPU训练时强烈建议设为True)。
    • 资源占用:仅对GPU有效,CPU训练时可忽略。
  • MNIST应用
  # GPU训练时启用内存锁定train_loader = DataLoader(..., pin_memory=True)
5、 drop_last(丢弃末批)
  • 定义:当数据集大小无法被batch_size整除时,是否丢弃最后一个不完整批次。
  • 作用
    • 避免小批次影响:丢弃末尾样本(如MNIST训练集60000样本,batch_size=64时最后一个批次含16样本)。
    • 分布式训练对齐:需所有批次大小一致时启用。
  • MNIST应用
    # 丢弃不完整批次
    train_loader = DataLoader(..., batch_size=64, drop_last=True)
    

代码示例

from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset,batch_size=64,shuffle=True,num_workers=4,pin_memory=True,drop_last=True
)

三、图片处理 transform

深度学习中,图像数据通常需要进行预处理(如缩放、裁剪、归一化等)以适应模型的输入要求。PyTorch 提供了 torchvision.transforms 模块,用于定义和实现这些图像处理操作。

transforms 的作用

transforms 是一个用于图像预处理的工具集,可以将一系列图像处理操作组合在一起,形成一个处理流水线(pipeline)。这些操作通常包括:

  • 数据增强:增加数据的多样性,防止模型过拟合。
  • 数据标准化:将数据转换为模型所需的格式(如归一化到特定范围)。
  • 数据转换:将图像转换为张量(Tensor)格式,以便输入模型。
常用 transforms 操作
1、基础操作
  • Resize: 调整图像大小。
transforms.Resize((height, width))  # 将图像调整为指定大小
  • CenterCrop: 从图像中心裁剪指定大小的区域。
transforms.CenterCrop(size)  # 裁剪大小为 (size, size)
  • RandomCrop: 随机裁剪图像。
transforms.RandomCrop(size)  # 随机裁剪大小为 (size, size)
  • RandomHorizontalFlip: 随机水平翻转图像。
transforms.RandomHorizontalFlip(p=0.5)  # 以 50% 的概率水平翻转
  • RandomRotation: 随机旋转图像。
transforms.RandomRotation(degrees=30)  # 随机旋转 ±30 度
2、 颜色变换
  • ColorJitter: 随机改变图像的亮度、对比度、饱和度和色调。
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
  • Grayscale: 将图像转换为灰度图。
transforms.Grayscale(num_output_channels=1)  # 转换为单通道灰度图
3、 归一化和标准化
  • ToTensor: 将图像(PIL 或 NumPy 格式)转换为 PyTorch 张量(Tensor),并将像素值从 [0, 255] 缩放到 [0, 1]。
transforms.ToTensor()

在使用 transforms.ToTensor() 处理图像后,PyTorch 会将图像的通道维度移动到最前面。
transforms.ToTensor() 的作用
1.将图像转换为张量:
输入的图像通常是 PIL 图像或 NumPy 数组,形状为 (H, W, C),其中:
H 是图像的高度(Height)。
W 是图像的宽度(Width)。
C 是图像的通道数(Channels,例如 RGB 图像为 3,灰度图像为 1)。
transforms.ToTensor() 会将图像转换为 PyTorch 张量(Tensor),并将像素值从 [0, 255] 缩放到 [0, 1]。

2通道维度的变化:
转换后的张量形状为 (C, H, W),即通道维度被移动到最前面。
这种格式是 PyTorch 的标准输入格式,便于后续的模型处理。

  • Normalize: 对图像进行标准化处理(减去均值,除以标准差)。
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

这里的均值和标准差通常是根据数据集计算的(例如 ImageNet 的均值和标准差)。

4、 组合操作
  • Compose: 将多个操作组合成一个流水线。
 transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
示例代码
from torchvision import datasets, transforms# 定义 transforms 流水线
transform = transforms.Compose([transforms.Resize((32, 32)),          # 调整图像大小为 32x32transforms.RandomHorizontalFlip(),    # 随机水平翻转transforms.ToTensor(),                # 转换为张量,并缩放到 [0, 1]transforms.Normalize((0.5,), (0.5,))  # 归一化到 [-1, 1]
])# 加载 MNIST 数据集并应用 transforms
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform
)# 查看处理后的图像
image, label = train_dataset[0]
print(image.shape)  # 输出: torch.Size([1, 32, 32])
总结
操作作用
Resize调整图像大小。
CenterCrop从图像中心裁剪指定大小的区域。
RandomCrop随机裁剪图像。
RandomHorizontalFlip随机水平翻转图像。
RandomRotation随机旋转图像。
ColorJitter随机改变图像的亮度、对比度、饱和度和色调。
Grayscale将图像转换为灰度图。
ToTensor将图像转换为张量,并缩放到 [0, 1]。
Normalize对图像进行标准化处理(减去均值,除以标准差)。
Compose将多个操作组合成一个流水线。
如何在数据加载过程中看到图片的样子

先轴交换,再利用make_grid合并再处理成数组.numpy()后,就可以展示出来

from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_gridmy_transforms = transforms.Compose([transforms.PILToTensor(),]
)
mnist_train = MNIST(root='./MNIST_data', train=True, download=True, transform=transforms.PILToTensor())
dataloader = DataLoader(mnist_train, batch_size=5, shuffle=True) #DataLoader 初始化
for (image, label) in dataloader:# 遍历 DataLoaderprint(image.shape) #torch.Size([5, 1, 28, 28])print(label) #tensor([3, 1, 2, 8, 3])print(make_grid(image).shape)  #torch.Size([3, 32, 152])  使用 make_grid 将图像拼接成网格image = make_grid(image).permute(1,2,0).numpy()#调整网格图像的维度并转换为 NumPy 数组plt.imshow(image) #使用 Matplotlib 显示图像plt.show()exit()

http://www.ppmy.cn/embedded/174215.html

相关文章

第七章 排序算法法法

算法时间复杂度 衡量一个算法的时间复杂度 度量一个程序(算法)执行时间的两种方法 事后统计法 这种方法可行,但是有两个问题:一是要想对涉及的算法的运行性能进行评测,需要实际运行该程序;二是所得时间的统计量依赖于计算机的硬件,软件等环境因素,这种方式,要在同一台计算机的…

【设计模式有哪些】

一、创建型模式(Creation Patterns) 1. 单例模式(Singleton) 核心思想:保证一个类仅有一个实例,并提供全局访问点。实现方式:public class Singleton {// 1. 私有静态实例,volatil…

【css酷炫效果】纯CSS实现悬浮弹性按钮

【css酷炫效果】纯CSS实现悬浮弹性按钮 缘创作背景html结构css样式完整代码效果图 想直接拿走的老板,链接放在这里:https://download.csdn.net/download/u011561335/90492020 缘 创作随缘,不定时更新。 创作背景 刚看到csdn出活动了&…

用 C 语言理解封装、继承、多态

前言 个人邮箱:zhangyixu02gmail.com本文主要是给一些做嵌入式软件开发,并且非计科的朋友做科普。使用 C 语言理解封装、继承、多态等概念。 正文 基类:最基础的结构体或函数。派生类:基类的继承自己的特性。封装:将…

桥接模式详解

以下是一个结合桥接模式解决实际开发问题的Java实现案例,涵盖多维度扩展、平台兼容性处理、渲染引擎解耦等场景需求,附带逐行中文注释: 场景描述 开发一个跨平台图形渲染框架,需支持: 图形类型扩展:圆形、…

iOS底层原理系列02-深入了解Objective-C

1. Objective-C的本质 用Objective-C编写的代码,底层其实都是C\C代码 所以Objective-C面向对象都是基于 C\C的数据结构(结构体)实现的。 Objective-C并非像其他语言那样在编译期完全确定程序的行为,而是将许多决策推迟到运行时进行,这种特性…

基于FPGA的DDS连续FFT 仿真验证

基于FPGA的 DDS连续FFT 仿真验证 1 摘要 本文聚焦 AMD LogiCORE IP Fast Fourier Transform (FFT) 核心,深入剖析其在 FPGA 设计中的应用。该 FFT 核心基于 Cooley - Tukey 算法,具备丰富特性,如支持多种数据精度、算术类型及灵活的运行时配置。文中详细介绍了其架构选项、…

JMeter基本介绍

Apache JMeter 工具详解 一、JMeter 简介 JMeter 是 Apache 基金会开源的 Java 应用程序,主要用于 性能测试、负载测试 和 功能测试。它通过对服务器或网络资源模拟多种负载条件(如并发用户、持续压力),帮助评估系统性能指标&am…