【漫话机器学习系列】119.小批量随机梯度方法

embedded/2025/3/6 17:09:15/

1. 引言

机器学习和深度学习中,梯度下降(Gradient Descent)是一种常见的优化算法,用于调整模型参数以最小化损失函数。然而,在处理大规模数据集时,使用传统的梯度下降(GD)可能会面临计算成本高、收敛速度慢等问题。因此,引入了小批量随机梯度下降(Mini-Batch Stochastic Gradient Descent,MB-SGD),它结合了全批量梯度下降(Batch GD)和随机梯度下降(SGD)的优点,成为深度学习训练中的标准方法。

本文将详细介绍小批量随机梯度方法的基本概念、数学原理、优缺点及其应用,并通过示例代码演示其实际使用方法。


2. 什么是小批量随机梯度下降?

小批量随机梯度下降(Mini-Batch SGD)是一种改进的梯度下降方法,它在每次参数更新时,只使用数据集中的一个小部分(小批量)来计算梯度,而不是整个数据集。

具体来说,小批量随机梯度下降的工作流程如下:

  1. 从数据集中随机抽取一个小批量(Mini-Batch)样本,大小通常为 32、64、128 等。
  2. 计算该小批量上的梯度,然后更新模型参数。
  3. 重复上述步骤,直到遍历整个数据集(一个 epoch)
  4. 重复多个 epoch,直到模型收敛

这一策略避免了全批量梯度下降计算量过大的问题,同时比单样本的随机梯度下降更稳定。


3. 小批量随机梯度下降的数学原理

3.1. 梯度下降基本公式

梯度下降的核心思想是沿着负梯度方向更新参数,从而最小化损失函数 J(θ)。其基本更新公式如下:

\theta = \theta - \alpha \nabla J(\theta)

其中:

  • θ 是模型参数
  • α 是学习率(learning rate)
  • ∇J(θ) 是损失函数关于参数的梯度

3.2. 全批量梯度下降(Batch Gradient Descent)

全批量梯度下降使用整个数据集来计算梯度:

\nabla J(\theta) = \frac{1}{N} \sum_{i=1}^{N} \nabla J_i(\theta)

其中 N 是数据集的大小。这种方法计算精确,但当数据量过大时,计算开销很高。

3.3. 随机梯度下降(Stochastic Gradient Descent, SGD)

随机梯度下降(SGD)每次只使用一个样本来计算梯度:

\theta = \theta - \alpha \nabla J_i(\theta)

由于仅使用一个样本进行更新,计算速度快,但梯度更新噪声较大,导致收敛不稳定。

3.4. 小批量随机梯度下降(Mini-Batch SGD)

小批量随机梯度下降在每次更新时使用一个小批量 B(包含多个样本)来计算梯度:

\theta = \theta - \alpha \frac{1}{|B|} \sum_{i \in B} \nabla J_i(\theta)

其中∣B∣ 是小批量的大小。该方法在计算效率收敛稳定性之间取得了良好的平衡。


4. 小批量随机梯度下降的优缺点

4.1. 优势

  • 减少计算开销:相比全批量梯度下降,小批量方法可以显著降低计算成本。
  • 提高收敛稳定性:相比随机梯度下降,小批量方法的梯度估计更加稳定,能更快地收敛。
  • 可利用并行计算:可以使用 GPU 进行矩阵运算,提高训练效率。
  • 易于处理大规模数据集:能够在数据量较大的情况下高效训练模型。

4.2. 劣势

  • 超参数敏感:小批量大小(batch size)和学习率的选择会影响模型性能。
  • 计算复杂度仍然较高:虽然比全批量下降快,但仍然比纯随机梯度下降计算量大。
  • 收敛可能不如全批量方法:由于梯度估计存在一定噪声,可能会导致收敛到局部最优解。

5. 代码示例

我们使用 Python 代码来实现小批量随机梯度下降。

5.1. 使用 NumPy 手动实现 Mini-Batch SGD

import numpy as np# 生成模拟数据
np.random.seed(42)
X = np.random.rand(100, 1)  # 100个样本,1个特征
y = 4 * X + np.random.randn(100, 1) * 0.2  # 线性关系 y = 4x + 噪声# 初始化参数
theta = np.random.randn(2, 1)
learning_rate = 0.1
epochs = 100
batch_size = 10# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]  # Mini-Batch SGD 训练
for epoch in range(epochs):shuffled_indices = np.random.permutation(100)  # 随机打乱数据X_b_shuffled = X_b[shuffled_indices]y_shuffled = y[shuffled_indices]for i in range(0, 100, batch_size):X_batch = X_b_shuffled[i:i + batch_size]y_batch = y_shuffled[i:i + batch_size]gradients = 2 / batch_size * X_batch.T.dot(X_batch.dot(theta) - y_batch)theta -= learning_rate * gradientsprint(f"训练后的参数: {theta}")

运行结果

训练后的参数: [[0.04320936][3.90884737]]

此代码实现了:

  1. 生成数据集并添加噪声。
  2. 使用 Mini-Batch SGD 进行参数更新。
  3. 训练完成后输出最终的参数值。

5.2. 使用 PyTorch 实现 Mini-Batch SGD

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 生成数据
X = torch.rand(100, 1)
y = 4 * X + torch.randn(100, 1) * 0.2# 构建数据集
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义模型
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)# 训练
epochs = 100
for epoch in range(epochs):for batch_X, batch_y in dataloader:optimizer.zero_grad()predictions = model(batch_X)loss = loss_fn(predictions, batch_y)loss.backward()optimizer.step()print(f"训练后的权重: {model.weight.data}, 偏置: {model.bias.data}")

运行结果

训练后的权重: tensor([[3.9055]]), 偏置: tensor([0.0890])

PyTorch 实现更加简洁,并且支持自动求导和 GPU 加速。


6. 结论

小批量随机梯度下降(Mini-Batch SGD)是一种高效且稳定的优化方法,它结合了全批量梯度下降的稳定性和随机梯度下降的计算效率,是深度学习训练中的标准方法。在实际应用中,需要通过调整学习率、批量大小和优化策略来获得最佳性能。

对于大规模数据集和深度学习任务,小批量方法能够显著提高训练速度,并支持并行计算,使得它成为现代机器学习的核心优化算法之一。


http://www.ppmy.cn/embedded/170528.html

相关文章

LangChain-08 Query SQL DB 通过GPT自动查询SQL

我们需要下载一个 LangChain 官方提供的本地小数据库。 安装依赖 SQL: https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql Shell: pip install --upgrade --quiet langchain-core langchain-community la…

[内网安全] Windows 本地认证 — NTLM 哈希和 LM 哈希

关注这个专栏的其他相关笔记:[内网安全] 内网渗透 - 学习手册-CSDN博客 0x01:SAM 文件 & Windows 本地认证流程 0x0101:SAM 文件简介 Windows 本地账户的登录密码是存储在系统本地的 SAM 文件中的,在登录 Windows 的时候&am…

GCC RISCV 后端 -- cc1 入口

GCC编译工具链中的 gcc 可执行程序,实际上是个驱动程序(Driver),其根据输入的参数,然后调用其它不同的程序,对输入文件进行处理,包括编译、链接等。可以通过以下命令查看: gcc -v h…

UV安装GPU版本PyTorch

经过同事推荐,开始尝试使用uv管理Python环境,效果相当不错。 安装PyTorch遇到的问题 但在安装PyTorch时,采用默认的uv add方式会报错,而使用uv pip install安装PyTorch的cuda版本,虽然没有问题,但并不能同…

C++(蓝桥杯常考点)

前言:这个是针对于蓝桥杯竞赛常考的C内容,容器这些等下棋期再讲 C 在DEVC中注释和取消注释的方法:ctrl/ ASCII值(常用的): A-Z:65-90 a-z:97-122 0-9:48-57 换行/n:10科学计数法:eg&#xff1a…

10.RabbitMQ集群

十、集群与高可用 RabbitMQ 的集群分两种模式,一种是默认集群模式,一种是镜像集群模式; 在RabbitMQ集群中所有的节点(一个节点就是一个RabbitMQ的broker服务器) 被归为两类:一类是磁盘节点,一类是内存节点; 磁盘节点会把集群的所有信息(比如交换机、绑…

vscode远程连接ubuntu/Linux(虚拟机同样适用)

前言 在现代开发环境中,远程工作和跨平台开发变得越来越普遍。Visual Studio Code(VSCode)作为一个流行的代码编辑器,提供了强大的远程开发功能,使得开发者能够高效地连接和管理远程 Linux 服务器上的项目。通过 VSCod…

【MySQL】索引|作用|底层数据结构|常见问题

目录 1.概念 2.为何引入 3.使用 (1)查看索引 (2)创建索引(危险操作) (3)删除索引(危险操作) 4.使用场景 🔥5.底层数据结构(核…