梯度下降算法的计算过程

news/2025/1/18 21:41:45/

1 小批量梯度下降(Mini-Batch Gradient Descent, MBGD)

  • 1.1划分数据集为多个小批量。
  • 1.2前向传播:对于每个小批量中的所有样本进行一次前向传播,得到预测输出。
  • 1.3计算损失:然后计算这些预测输出相对于真实标签的总损失。通常是累加每个样本的损失来完成。
  • 1.4反向传播:执行反向传播以计算当前小批量上损失函数关于模型参数的梯度,这是通过自动微分工具自动完成,它会为每一个参数计算出一个梯度值。
  • 1.5计算平均梯度
    • 前向传播:对于一个给定的小批量(mini-batch),假设包含m个样本。对于每个样本 x i {x}_{i} xi,通过前向传播计算出预测值 y i ^ = f ( x i ; θ ) \hat{{y}_{i}}=f({x}_{i};\theta) yi^=f(xi;θ) y i ^ \hat{{y}_{i}} yi^是关于样本值和模型参数的函数。
    • 计算损失:基于预定义的损失函数计算预测值和标签值的差异,即损失。损失函数形式为: J ( x i , y i ; θ ) = L ( y i ^ , y i ) J({x}_{i},{y}_{i};\theta)=L(\hat{{y}_{i}}, {y}_{i}) J(xi,yi;θ)=L(yi^,yi) J J J是关于 ( y i ^ , y i ) (\hat{{y}_{i}}, {y}_{i}) (yi^,yi)的函数。
    • 反向传播:基于链式法则,从输出层开始,逐层向后计算梯度。具体来说,对于每一层的参数 θ j \theta_{j} θj,计算该参数的梯度 ∇ θ j J ( x i , y i ; θ j ) \nabla_{\theta_{j}}J({x}_{i},{y}_{i};\theta_{j}) θjJ(xi,yi;θj)
      ∂ L ∂ θ j = ∂ L ∂ y ^ ⋅ ∂ y ^ ∂ θ j \frac{\partial L}{\partial \theta_{j}}=\frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial \theta_{j}} θjL=y^Lθjy^
      由于每个小批量有多个样本,反向传播会得到一组梯度值,最终结果取梯度的平均值。
      ∇ θ j J ˉ = 1 m ∑ i = 1 m ∇ θ j J ( x i , y i ; θ j ) \nabla_{\theta_{j}}\bar{J}=\frac{1}{m}\sum_{i=1}^{m}\nabla_{\theta_{j}}J({x}_{i},{y}_{i};\theta_{j}) θjJˉ=m1i=1mθjJ(xi,yi;θj)
    • 参数更新:基于上述计算出的平均梯度更新模型参数。对于每个参数 θ j \theta_{j} θj,按照以下公式进行更新:
      θ j : = θ j − ϵ ∇ θ j J ˉ \theta_{j} :=\theta_{j} - \epsilon\nabla_{\theta_{j}}\bar{J} θj:=θjϵθjJˉ,其中 ϵ \epsilon ϵ是模型学习率。

2 带动量的梯度下降

  • 2.1设置学习率 ϵ \epsilon ϵ和动量参数 α \alpha α
  • 2.2 计算当前小批量的平均梯度
    g = 1 m ∑ i = 1 m ∇ θ j J ( x i , y i ; θ j ) g=\frac{1}{m}\sum_{i=1}^{m}\nabla_{\theta_{j}}J({x}_{i},{y}_{i};\theta_{j}) g=m1i=1mθjJ(xi,yi;θj)
  • 2.3 计算速度更新
    ν ← α ν − ϵ g \nu \gets \alpha\nu - \epsilon g νανϵg
  • 2.4更新参数
    θ ← θ + ν \theta \gets \theta + \nu θθ+ν

http://www.ppmy.cn/news/1564242.html

相关文章

Pycharm报错:DeprecationWarning: sipPyTypeDict() is deprecated

原因:这个警告是由SIP库引发的,通常不会导致应用程序出现问题。警告表明应用程序中使用了不推荐使用的SIP函数,但在大多数情况下,这些警告可以被忽略。 SIP是用于创建Python和C之间的桥接的库,用于让Python扩展能够与…

Go语言之路————数组、切片、map

Go语言之路————数组、切片、map 前言一、数组二、切片三、map 前言 我是一名多年Java开发人员,因为工作需要现在要学习go语言,Go语言之路是一个系列,记录着我从0开始接触Go,到后面能正常完成工作上的业务开发的过程&#xff…

elasticsearch线程池配置

在Elasticsearch中,默认的线程池配置如下: search线程池 用途:用于处理搜索请求。 特点: 类型为fixed,即固定大小的线程池。 线程数根据分配给Elasticsearch的处理器数量动态计算,以确保搜索请求能够并行…

matlab中的griddata函数

在Matlab中,griddata函数用于对二维或三维散点数据进行插值。griddata函数支持多种插值方法,其中包括natural方法。以下是关于griddata函数与natural插值方法的关系的详细说明: griddata函数概述 griddata函数主要用于将不规则分布的数据点…

大模型微调介绍-Prompt-Tuning

提示微调入门 NLP四范式 第一范式 基于「传统机器学习模型」的范式,如TF-IDF特征朴素贝叶斯等机器算法. 第二范式 基于「深度学习模型」的范式,如word2vec特征LSTM等深度学习算法,相比于第一范式,模型准确有所提高&#xff0c…

数据结构-栈队列OJ题

文章目录 一、有效的括号二、用队列实现栈三、用栈实现队列四、设计循环队列 一、有效的括号 (链接:ValidParentheses) 这道题用栈这种数据结构解决最好,因为栈有后进先出的性质。简单分析一下这道题:所给字符串不是空的也就是一定至少存在一…

华为手机改ip地址能改定位吗

‌在数字化时代,手机不仅是通讯工具,更是我们日常生活的得力助手。从地图导航到社交媒体,手机定位服务无处不在。然而,有时我们可能出于隐私保护或其他原因,希望更改手机的IP地址,并好奇这是否能同时改变手…

Spring Boot中的自动配置原理是什么

Spring Boot 自动配置原理 Spring Boot 的自动配置机制基于 条件化配置,通过 EnableAutoConfiguration 注解来启用。自动配置的核心原理是 基于类路径和环境条件来推断所需要的配置,Spring Boot 会根据项目中引入的依赖和当前环境来自动装配相关的配置项…