基于多层感知机(MLP)实现MNIST手写体识别

embedded/2025/2/28 9:47:11/

实现步骤

  1. 下载数据集
  2. 处理好数据集
  3. 确定好模型(初始化模型参数等等)
  4. 确定优化函数(损失函数也称为目标函数)和优化方法(一般选用随机梯度下降 SDG )
  5. 进行模型的训练
  6. 进行模型的评估
import torch
import torchvision
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 1. 下载数据集
mnist_train = torchvision.datasets.MNIST(root='../data', train=True, transform=transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root='../data', train=False, transform=transforms.ToTensor(), download=True)# 2. 创建批量数据迭代器
train_iter = DataLoader(mnist_train, batch_size=256, shuffle=True)
test_iter = DataLoader(mnist_test, batch_size=256)# 3. 可视化检查数据
var = next(iter(train_iter))
plt.title(str(var[1][0]))  # 显示标签
plt.imshow(var[0][0].squeeze().numpy(), cmap='gray')  # 显示图片
plt.show()# 4. 定义模型:多层感知机
net = nn.Sequential(nn.Flatten(),nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 256),nn.ReLU(),nn.Linear(256, 10) # 注意这里是不需要加 Softmax 了的,因为后面定义了,nn.CrossEntropyLoss()这个会自动帮我们进行 Softmax 以及进行损失计算。其实就是目标函数
)# 初始化模型参数
def init_weights(m):if isinstance(m, nn.Linear):nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)# 5. 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # CrossEntropyLoss已经包含了softmax,所以不需要LogSoftmax
optimizer = optim.SGD(net.parameters(), lr=0.2)# 6. 训练模型
epoch_num = 20
for epoch in range(epoch_num):net.train()  # 设置为训练模式total_loss = 0for X, y in train_iter:optimizer.zero_grad()  # 清除梯度y_hat = net(X)  # 前向传播loss = loss_fn(y_hat, y)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数total_loss += loss.item() * X.shape[0]  # 累积损失avg_loss = total_loss / len(mnist_train)  # 计算平均损失print(f'Epoch {epoch + 1}/{epoch_num}, Loss: {avg_loss:.4f}')# 7. 评估模型
def evaluate_model(net, test_iter):net.eval()  # 设置为评估模式correct, total = 0, 0with torch.no_grad():  # 在评估时不需要计算梯度for X, y in test_iter:y_hat = net(X)_, predicted = torch.max(y_hat, 1)  # 获取预测的标签correct += (predicted == y).sum().item()  # 统计正确的个数total += y.size(0)  # 统计总数accuracy = correct / totalprint(f'Accuracy on test set: {accuracy * 100:.2f}%')# 评估模型的表现
evaluate_model(net, test_iter)

代码实践的结果:

  1. 自己不会去计算损失
  2. 在模型进行训练的时候不知道如何把数据放进去:原来只需要创建好了 DataLoader 以后,通过枚举
    就可以拿到数据了。
  3. 最后进行模型评估的时候也是用 AI 进行完成了。所以多少还是差点意思。
  4. 后面的代码多去实践实践,并且思考吧!!!

关于代码中交叉熵计算的理解

理解损失函数(loss_fn)是如何计算的,对于训练神经网络来说是非常重要的。具体到你提到的这行代码:

loss = loss_fn(y_hat, y)  # 计算损失

损失函数的定义:

在你的代码中,损失函数是:

loss_fn = nn.CrossEntropyLoss()

nn.CrossEntropyLoss() 是一种常用于多分类问题的损失函数,它实际上包含了两个步骤:

  1. Softmax:将模型的输出转换为概率分布。
  2. 交叉熵损失:计算真实标签与预测概率分布之间的差距。

为什么要用交叉熵呢?因为交叉熵可以来衡量预测差距,这个我们只需要这个知识点,并且知道上面的公式就好了。

我们逐步分析这两个步骤。

1. Softmax(概率转换)

假设模型的输出 y_hat 是一个向量,其中每个元素代表对应类别的“分数”(或者说是原始的 logits)。例如,假设有 3 个类别,模型的输出可能是:

y_hat = [2.0, 1.0, -1.0]  # 这三个数字是 logits,不是概率

通过 Softmax 函数,我们将这些 logits 转换成概率:

# 计算 softmax
softmax = torch.nn.functional.softmax(y_hat, dim=-1)

softmax 的输出会是一个概率分布,每个数值的范围在 [0, 1] 之间,且所有数值加起来为 1。例如,经过 Softmax 后可能得到:

softmax = [0.7, 0.2, 0.1]  # 类别 0 的概率是 0.7,类别 1 的概率是 0.2,类别 2 的概率是 0.1

2. 交叉熵损失(Cross Entropy Loss)

交叉熵是衡量两个概率分布之间差异的一个标准方法。在分类任务中,我们希望预测的类别概率与真实标签分布尽可能接近。

对于一个单一的样本,交叉熵损失的计算公式为:

L = − ∑ i = 1 C y i log ⁡ ( p i ) L = - \sum_{i=1}^{C} y_i \log(p_i) L=i=1Cyilog(pi)

  • ( C ) 是类别数。
  • ( y_i ) 是真实标签(在 one-hot 编码下,真实类别的标签为 1,其他类别为 0)。
  • ( p_i ) 是模型预测的概率。

对于多分类任务来说,交叉熵损失会选择对应真实标签的类别概率 ( p_{\text{true}} ) 来计算损失。例如,如果真实标签是类别 0,那么我们只关注模型在类别 0 上的预测概率。

假设真实标签 y 是类别 0,对应的 one-hot 编码是 [1, 0, 0],而模型的预测是:

softmax = [0.7, 0.2, 0.1]

那么交叉熵损失为:

L = − ( 1 ⋅ log ⁡ ( 0.7 ) + 0 ⋅ log ⁡ ( 0.2 ) + 0 ⋅ log ⁡ ( 0.1 ) ) = − log ⁡ ( 0.7 ) ≈ 0.3567 L = - (1 \cdot \log(0.7) + 0 \cdot \log(0.2) + 0 \cdot \log(0.1)) = - \log(0.7) \approx 0.3567 L=(1log(0.7)+0log(0.2)+0log(0.1))=log(0.7)0.3567

nn.CrossEntropyLoss() 如何工作

在 PyTorch 中,nn.CrossEntropyLoss 会自动处理上述两个步骤:

  1. y_hat(logits)转换为概率。
  2. 使用真实标签 y 计算交叉熵损失。
输入和输出:
  • y_hat: 这是模型的原始输出(logits),形状为 (batch_size, num_classes)。每一行是一个样本的 logits。
  • y: 这是标签,通常是一个包含类别索引的向量,形状为 (batch_size,)。每个元素是该样本的真实类别索引。

例如:

假设我们有以下数据:

  • 模型的输出(logits)为:

    y_hat = torch.tensor([[2.0, 1.0, -1.0],  # 第一个样本[0.5, 1.5, 0.3]]) # 第二个样本
    
  • 真实标签 y 为:

    y = torch.tensor([0, 1])  # 第一个样本的标签是类别 0,第二个样本的标签是类别 1
    

使用 nn.CrossEntropyLoss() 计算损失:

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(y_hat, y)

CrossEntropyLoss 会首先对 y_hat 进行 softmax 转换,然后计算每个样本的交叉熵损失。你可以通过打印出来的 loss 来查看模型的表现。

总结:

  • y_hat 是模型的原始输出(logits),表示每个类别的“分数”。
  • nn.CrossEntropyLoss 会自动处理 softmax 和交叉熵损失的计算。
  • 损失函数的目的是衡量模型的输出与真实标签之间的差异,差异越小,损失值越小,说明模型的预测越准确。

使用`nn.CrossEntropyLoss 会自动进行独热编码

在计算交叉熵损失时,nn.CrossEntropyLoss 会自动处理标签,并且不需要你手动将标签转换为独热编码(one-hot encoding)。

具体来说:

  • y_hat:是模型的原始输出(logits),形状为 (batch_size, num_classes),每一行是一个样本的预测结果,包含每个类别的分数(logits)。
  • y:是标签,形状为 (batch_size,),每个元素是该样本的真实类别的 索引,而不是独热编码。

nn.CrossEntropyLoss 会自动使用标签 y 中的类别索引(如类别 0, 1, 2)来计算损失,它会根据该类别索引选择对应的模型输出进行计算,而不需要你事先将标签转换为独热编码。

举个例子:

假设我们有一个批次的两个样本,模型的输出 y_hat 和真实标签 y 如下:

模型的输出 y_hat(logits):
y_hat = torch.tensor([[2.0, 1.0, -1.0],  # 第一个样本的 logits[0.5, 1.5, 0.3]]) # 第二个样本的 logits
真实标签 y(类别索引):
y = torch.tensor([0, 1])  # 第一个样本的标签是类别 0,第二个样本的标签是类别 1

在这个例子中,y_hat 的形状是 (2, 3),表示有两个样本,每个样本有三个类别的 logits。

  • 对于第一个样本,它的真实标签是类别 0y[0] = 0
  • 对于第二个样本,它的真实标签是类别 1y[1] = 1

当使用 nn.CrossEntropyLoss 时,它会根据真实标签中的类别索引来选择对应的类别概率(通过 softmax 后的结果),然后计算交叉熵损失。PyTorch 会自动完成:

  1. Softmax 转换:将 y_hat 的 logits 转换为概率分布。
  2. 损失计算:根据真实标签的类别索引计算交叉熵损失。

损失计算过程:

假设 y_hat 的 softmax 输出(经过 softmax 转换后)为:

softmax = torch.nn.functional.softmax(y_hat, dim=1)# softmax 结果:
# [[0.7, 0.2, 0.1],  # 第一个样本的概率
#  [0.3, 0.6, 0.1]]  # 第二个样本的概率

然后计算交叉熵损失时:

  • 对于第一个样本(真实标签是 0),损失将仅计算 softmax[0][0](即类别 0 的概率),即 -log(0.7)
  • 对于第二个样本(真实标签是 1),损失将仅计算 softmax[1][1](即类别 1 的概率),即 -log(0.6)

因此,nn.CrossEntropyLoss 会自动从 y 中提取类别索引,并计算相应的损失。

总结:

  • 不需要手动进行独热编码,因为 nn.CrossEntropyLoss 会自动处理。
  • 标签 y 只需要是一个包含类别索引的张量(每个元素是一个整数,表示类别)。
  • nn.CrossEntropyLoss 会根据标签中的类别索引,从 y_hat 中提取对应的 logits,然后计算交叉熵损失。

通过这种方式,PyTorch 简化了损失计算的过程,让你不需要显式地进行独热编码的转换。

为什么在实验的时候我去掉了参数初始化对结果并没有很大影响?

在深度学习模型中,尤其是在多层感知机(MLP)中,初始化模型参数对于训练的收敛速度和结果有重要影响。虽然初始化通常有助于提升训练过程的稳定性和准确性,但在某些情况下,即使没有显式初始化,训练仍然可能得到较好的结果。这有以下几个原因:

1. PyTorch 默认的权重初始化

如果你没有手动初始化模型参数,PyTorch 会使用默认的初始化方式。对于 nn.Linear 层,PyTorch 默认使用的是 均匀分布初始化U(-sqrt(k), sqrt(k))),其中 k 是每一层的输入特征数(即该层的输入维度)。这种初始化方式通常可以在大部分场景下有效工作,尤其是在简单的任务或小型模型中。

2. 任务本身比较简单

MNIST 数据集是一个相对简单的任务,具有以下特点:

  • 样本相对简单(28x28 的灰度图像)。
  • 类别数量较少(10 个类别)。
  • 数据集规模较小(60,000 个训练样本)。

由于这些原因,即使没有特别优化初始化方式,模型仍然能在训练过程中较好地拟合数据,因此准确率可能不会受到显著影响。

3. 优化器的鲁棒性

现代优化器(如 SGD、Adam 等)通常具有较强的鲁棒性,能够在一定范围内有效地调整模型的参数,避免了初始化差异带来的过度影响。即使没有进行显式初始化,优化器也能够逐步调整模型的参数,从而避免梯度消失或梯度爆炸等问题,保证训练的顺利进行。

4. 训练过程中参数的调整

在模型训练初期,即使初始化不完美,随着训练的进行,网络的权重会在反向传播过程中逐步调整到合适的值。因此,即使开始时的参数较为随机,优化过程仍然能够找到有效的解决方案。这就是深度学习的一个特性:即使参数初始不理想,优化过程通常能通过梯度更新找到合适的解。

5. 初始化不影响最终收敛结果

对于一些简单的任务,模型可能在多个初始化条件下都能够达到一个相对接近的局部最优解。在这种情况下,即使没有手动初始化权重,模型也能收敛到较好的解。

总结:

  • 默认初始化(PyTorch 内部的初始化方式)通常已经能在很多简单的任务中有效工作,特别是像 MNIST 这样简单的图像分类任务。
  • 优化器的鲁棒性帮助模型调整参数,避免了初始化不完美时对结果产生显著影响。
  • 对于 MNIST 这种简单任务,初始化参数的不同可能不会导致显著差异,尤其是在训练的过程中,优化器能够找到较好的解。

然而,在一些更复杂的任务中,初始化的方式会直接影响模型的训练效率和性能。在这些任务中,精心设计的初始化(例如 Xavier、He 初始化等)能够帮助模型更快地收敛并避免训练过程中遇到的问题。


http://www.ppmy.cn/embedded/167773.html

相关文章

钉钉小程序(企业内部应用)开发下载预览文件

先转存钉盘,在下载 转存钉盘相关API为dd.saveFileToDingTalk调用钉盘预览文件的接口来预览:相关API为dd.previewFileInDingTalk在预览界面有下载的方式,可以直接下载 goPDF() {dd.saveFileToDingTalk({url: http://elinkshop.oss-cn-shanghai.ali…

学习Flask:[特殊字符] Day 4:REST API开发

学习目标:构建规范的API接口 from flask_restful import Api, Resourceapi Api(app)class PostAPI(Resource):def get(self, post_id):post Post.query.get_or_404(post_id)return {title: post.title,author: post.author.username}api.add_resource(PostAPI, /…

Spark技术系列(三):Spark算子全解析——从基础使用到高阶优化

Spark技术系列(三):Spark算子全解析——从基础使用到高阶优化 1. 算子核心概念与分类体系 1.1 算子本质解析 延迟执行机制:转换算子构建DAG,行动算子触发Job执行任务并行度:由RDD分区数决定(可通过spark.default.parallelism全局配置)执行位置优化:基于数据本地性的…

电脑显示屏亮度怎么调?电脑屏幕亮度调节步骤介绍

电脑屏幕亮度是指电脑显示器发出的光线的强度,它会影响我们的视觉效果和舒适度。电脑屏幕亮度过高或过低,都可能会对我们的眼睛造成伤害,所以我们需要根据不同的环境和需求,适时地调节电脑屏幕亮度。电脑屏幕亮度的调节方法有以下…

玩转 Netty : 如何设计高性能RPC通信组件

1、概述 前面我们学习了 Netty 的基本用法,以及内部涉及到的一些组件的概念,最后还开发了一款 HTTP 应用服务器,相信你已经知道了 Netty 是什么,可以用来做什么了。今天我们就重新回到 Cheese 中,我们今天的学习目标是…

HTML/CSS/JS

技术栈 前端 : HTML CSS JavaScript ES6 Nodejs npm vite vue3 router pinia axios element-plus 后端&#xff1a;HTTP xml Tomcat Servlet Request Response Cookie Sesssion Filter Listener MySQL JDBC Druid Jackson lombok jwt . HTML <!DOCTYPE html> 文档声…

Vue 3 + Vite 项目配置访问地址到服务器某个文件夹的解决方案

前言 在开发 Vue 3 Vite 项目时&#xff0c;我们经常需要将项目部署到服务器的某个特定文件夹下。例如&#xff0c;将项目部署到 /my-folder/ 目录下&#xff0c;而不是服务器的根目录。这时&#xff0c;我们需要对 Vite 和 Vue Router 进行一些配置&#xff0c;以确保项目能…

团队协作中的分支合并:构建高效开发流程的关键

项目场景 git pull origin 直接用 git pull 就能拉取远程仓库的分支 这是什么原理? git pull 命令会从远程仓库拉取最新的更改并合并到当前分支。它的具体行为取决于你是否指定了远程仓库和分支名称。 git pull 的默认行为 如果你只使用 git pull 而没有指定远程仓库和分支名…