利用卷积神经网络进行手写数字的识别

server/2024/12/19 16:19:21/

数据集介绍

MNIST(Modified National Institute of Standards and Technology)数据集是一个广泛使用的手写数字识别数据集,常用于机器学习和计算机视觉领域中的分类任务。它包含了从0到9的手写数字样本,常用于训练和测试各种图像分类算法。

数据集概况

MNIST数据集由60,000个训练样本和10,000个测试样本组成,每个样本是一张28×28像素的灰度图像,表示一个手写数字。每个图像是一个二维矩阵,像素值范围从0(黑色)到255(白色),灰度值表示不同的颜色深度。数据集中的标签是这些图像对应的数字(0-9)。

数据集格式

  • 训练集:60,000个图像,每个图像有一个对应的标签(0到9之间的数字)。
  • 测试集:10,000个图像,也有对应的标签。

使用场景

  1. 图像分类任务:由于数据集较小且标准化,MNIST是机器学习算法(尤其是深度学习模型)测试和比较性能的一个标准数据集。
  2. 模型性能评估:MNIST被广泛用于评估各种机器学习模型的效果,尤其是在图像处理领域。
  3. 教学:由于其简单性,MNIST常作为入门学习机器学习和神经网络的教学材料。

特点

  • 图像尺寸固定:28×28像素,适合用作标准输入。
  • 图像内容简单:大多数手写数字都是规范且易于分辨的。
  • 数据集较小,适合于快速实验和初步的模型验证。

数据集获取

MNIST数据集可以通过多个平台获取,例如:

  • 通过TensorFlow、PyTorch等框架的内建API加载。
  • 从MNIST官网下载。

数据预处理及参数选择

数据处理

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# softmax归一化指数函数(https://blog.csdn.net/lz_peter/article/details/84574716),其中0.1307是mean均值和0.3081是std标准差train_dataset = datasets.MNIST(root='./data/mnist', train=True, transform=transform)  # 本地没有就加上download=True
test_dataset = datasets.MNIST(root='./data/mnist', train=False, transform=transform)  # train=True训练集,=False测试集
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

参数的选择

batch_size = 64                #每个批次大小中有64个样本
learning_rate = 0.01           #学习率
momentum = 0.5                 #梯度下降冲量
epochs = 10                    #训练轮数
  • batch_size = 64:每次训练时使用64个样本来计算梯度并更新权重。
  • learning_rate = 0.01:每次权重更新时,步长为0.01,影响训练速度和稳定性。
  • momentum = 0.5:通过加权平均过去的梯度,帮助加速收敛并减少梯度更新的震荡。
  • epochs = 10:模型将在训练数据上进行10次完整的迭代,通常可以在这个范围内找到适合的训练状态。

网络模型

class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1, 10, kernel_size=5),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2),)self.conv2 = torch.nn.Sequential(torch.nn.Conv2d(10, 20, kernel_size=5),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2),)self.conv3 = torch.nn.Sequential(torch.nn.Flatten(),torch.nn.Linear(320, 50),torch.nn.Linear(50, 10),)def forward(self, x):batch_size = x.size(0)x = self.conv1(x)  # 一层卷积层,一层池化层,一层激活层(图是先卷积后激活再池化,差别不大)x = self.conv2(x)  # 再来一次x = self.conv3(x)return x  # 最后输出的是维度为10的,也就是(对应数学符号的0~9)
  • 输入层:

    • 输入尺寸:每张输入图像是28x28像素的灰度图,单通道。输入张量的形状为 (batch_size, 1, 28, 28),其中 batch_size 是一次处理的图像数量,1 是表示单通道的灰度图像,28x28 是图像的尺寸。
  • 第一层卷积层(conv1):

    • 卷积层:使用一个大小为 5x5 的卷积核,将输入图像的1个通道(灰度)转换为10个通道。卷积核的步幅为1,填充为0(即没有边缘扩展)。这会产生一个大小为 24x24 的特征图(由于没有填充,尺寸会减少)。
    • 激活函数:ReLU(Rectified Linear Unit),它会对每个像素值进行非线性转换(ReLU(x) = max(0, x)),有效地引入了非线性特性。
    • 池化层:最大池化层使用 2x2 的池化窗口和步幅为2。池化操作减少了特征图的尺寸,将每个 2x2 的区域映射为最大值。池化操作将图像尺寸减半,从 24x24 减小为 12x12,同时减少计算量。
  • 第二层卷积层(conv2):

    • 卷积层:卷积核的大小为 5x5,将前一层输出的10个通道转换为20个通道。同样,步幅为1,没有填充。这个操作将特征图的大小从 12x12 减少到 8x8
    • 激活函数:使用ReLU激活函数。
    • 池化层:再次使用最大池化,池化窗口为 2x2,步幅为2。此操作将尺寸从 8x8 减小为 4x4
  • 全连接层(conv3):

    • 展平操作(Flatten):经过两层卷积和池化操作后,输出特征图的大小为 20x4x4。在传入全连接层之前,需要将这个多维的张量展平成一维向量。展平后的尺寸是 320(即 20 * 4 * 4)。
    • 第一个全连接层:将展平后的320个元素映射到50个神经元。该层的作用是通过加权和偏置的线性变换对输入进行处理,并通过激活函数进行非线性转换。
    • 第二个全连接层:将50个神经元映射到10个神经元,输出的每个神经元代表一个数字类别(0到9)。
  • 输出层:

    • 输出尺寸:最终输出为一个10维的向量,其中每个值表示输入图像属于每个类别的“分数”。这个分数可以通过softmax层转化为概率,用于多类分类任务。

模型训练

# Construct loss and optimizer ------------------------------------------------------------------------------
loss_f = torch.nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)  # lr学习率,momentum冲量# Train and Test CLASS --------------------------------------------------------------------------------------
# 把单独的一轮一环封装在函数类里
def train(epoch):running_loss = 0.0  # 这整个epoch的loss清零running_total = 0running_correct = 0for batch_idx, data in enumerate(train_loader, 0):  #第一个代表训练的批次,data中包括数据和标签,第一个数据代表输入即inputs,第二个数据代表标签labelsinputs, target = dataoptimizer.zero_grad()   #将之前的梯度清零# forward + backward + updateoutputs = model(inputs)loss = loss_f(outputs, target)#反向传播loss.backward()#参数更新optimizer.step()# 把运行中的loss累加起来,为了下面300次一除running_loss += loss.item()# 把运行中的准确率acc算出来_, predicted = torch.max(outputs.data, dim=1)running_total += inputs.shape[0]running_correct += (predicted == target).sum().item()if batch_idx % 300 == 299:  # 不想要每一次都出loss,浪费时间,选择每300次出一个平均损失,和准确率print('[%d, %5d]: loss: %.3f , acc: %.2f %%'% (epoch + 1, batch_idx + 1, running_loss / 300, 100 * running_correct / running_total))running_loss = 0.0  # 这小批300的loss清零running_total = 0running_correct = 0  ## 这小批300的acc清零# torch.save(model.state_dict(), './model_Mnist.pth')# torch.save(optimizer.state_dict(), './optimizer_Mnist.pth')def test():correct = 0total = 0with torch.no_grad():  # 测试集不用算梯度for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度,沿着行(第1个维度)去找1.最大值和2.最大值的下标total += labels.size(0)  # 张量之间的比较运算correct += (predicted == labels).sum().item()acc = correct / totalprint('[%d / %d]: Accuracy on test set: %.1f %% ' % (epoch + 1, epochs, 100 * acc))  # 求测试的准确率,正确数/总数return acc# Start train and Test --------------------------------------------------------------------------------------
if __name__ == '__main__':acc_list_test = []for epoch in range(epochs):train(epoch)# if epoch % 10 == 9:  #每训练10轮 测试1次acc_test = test()acc_list_test.append(acc_test)plt.plot(acc_list_test)plt.xlabel('Epoch')plt.ylabel('Accuracy On TestSet')plt.show()

训练结果


http://www.ppmy.cn/server/151479.html

相关文章

2024年贵州技能大赛暨全省第二届数字技术应用职业技能竞赛“信息通信网络运行管理员”赛题

1.分析attack.pcapng数据包文件,通过分析数据包 attack.pcapng找出恶意用户第一次访问HTTP服务的数据包是第 几号,将该号数作为Flag值提交; flag{277} 2.继续查看数据包文件attack.pcapng,分析出恶意用户扫 描了哪些端口&#xf…

深度学习中的激活函数

激活函数(activation function)是应用于网络中各个神经元输出的简单变换,为其引入非线性属性,使网络能够对更复杂的数据进行建模,使其能够学习更复杂的模式。如果没有激活函数,神经元只会对输入进行枯燥的线性数学运算。这意味着&…

视频生成缩略图

文章目录 视频生成缩略图使用ffmpeg 视频生成缩略图 最近有个需求&#xff0c;视频上传之后在列表和详情页需要展示缩略图 使用ffmpeg 首先引入jar包 <dependency><groupId>org.bytedeco</groupId><artifactId>javacpp</artifactId><vers…

《向量数据库指南》——Milvus Cloud 2.5:Sparse-BM25引领全文检索新时代

Milvus Cloud BM25:重塑全文检索的未来 在最新的Milvus Cloud 2.5版本中,我们自豪地引入了“全新”的全文检索能力,这一创新不仅巩固了Milvus Cloud在向量数据库领域的领先地位,更为用户提供了前所未有的灵活性和效率。作为大禹智库的向量数据库高级研究员,以及《向量数据…

低延迟!实时处理!中软高科AI边缘服务器,解决边缘计算多样化需求!

根据相关统计&#xff0c;随着物联网的发展和5G技术的普及&#xff0c;到2025年&#xff0c;全球物联网设备连接数将达到1000亿&#xff0c;海量的计算数据使得传输到云端再处理的云计算方式显得更捉襟见肘。拥有低延迟、实时处理、可扩展性和更高安全性的边缘计算应运而生&…

使用k6进行kafka负载测试

1.安装环境 kafka环境 参考Docker搭建kafka环境-CSDN博客 xk6-kafka环境 ./xk6 build --with github.com/mostafa/xk6-kafkalatest 查看安装情况 2.编写脚本 test_kafka.js // Either import the module object import * as kafka from "k6/x/kafka";// Or in…

【ubuntu18.04】ubuntu18.04挂在硬盘出现 Wrong diagnostic page; asked for 1 got 8解决方案

错误日志 [ 8754.700227] usb 2-3: new full-speed USB device number 3 using xhci_hcd [ 8754.867389] usb 2-3: New USB device found, idVendor0e0f, idProduct0002, bcdDevice 1.00 [ 8754.867421] usb 2-3: New USB device strings: Mfr1, Product2, SerialNumber0 [ 87…

Pikachu-XXE靶场(注入攻击)

1.攻击测试 <?xml version"1.0"?> <!DOCTYPE foo [ <!ENTITY xxe "a" > ]> <foo>&xxe;</foo> 2.查看文件 <?xml version"1.0"?> <!DOCTYPE foo [ <!ENTITY xxe SYSTEM "file:///E:/ph…