神经网络反向传播交叉熵 计算损失函数对隐藏层激活值a1的梯度

news/2024/10/21 19:33:29/

本文是交叉熵损失函数为代表的两层神经网络反向传播量化求导计算公式中的一个公式,单独拿出来做一下解释说明。


公式 8-16反向传播算法中,用于计算损失函数对隐藏层激活值 a 1 a_1 a1 的梯度。在反向传播过程中,损失函数对隐藏层激活值的梯度是非常重要的一步,因为它将误差从输出层传递到隐藏层,最终用于调整隐藏层的权重和偏置。

公式 8-16 的表达式为:
∂ L ∂ a 1 = ( a 2 − y ) w 2 (8-16) \frac{\partial L}{\partial a_1} = (a_2 - y) w_2 \tag{8-16} a1L=(a2y)w2(8-16)

1. 公式中的符号解释

  • L L L:损失函数,通常是交叉熵损失函数。
  • a 1 a_1 a1:隐藏层的激活值,这是通过激活函数(如 sigmoid、ReLU)计算得到的隐藏层输出。
  • a 2 a_2 a2:输出层的激活值,即模型的预测结果(通过输出层的激活函数计算得到)。
  • w 2 w_2 w2:输出层权重矩阵,它连接隐藏层和输出层神经元。
  • y y y:真实标签,表示样本的真实类别。
  • ∂ L ∂ a 1 \frac{\partial L}{\partial a_1} a1L:损失函数对隐藏层激活值的梯度,这个梯度用于进一步计算损失函数对隐藏层权重和偏置的影响。

2. 公式的含义

公式 8-16 表示的是损失函数对隐藏层激活值 a 1 a_1 a1 的梯度。通过链式求导,损失函数对隐藏层的梯度可以通过输出层的误差 a 2 − y a_2 - y a2y输出层权重 w 2 w_2 w2 来计算。这个公式的推导与反向传播的链式法则密切相关,因为它需要将输出层的误差“反向传播”到隐藏层。

3. 推导过程

为了推导公式 8-16,我们可以通过链式法则一步一步进行计算。

1. 链式法则的应用

我们知道,损失函数 L L L 依赖于输出层的激活值 a 2 a_2 a2,而输出层的激活值 a 2 a_2 a2 又依赖于隐藏层的激活值 a 1 a_1 a1。根据链式法则,损失函数对隐藏层激活值 a 1 a_1 a1 的导数可以写为:
∂ L ∂ a 1 = ∂ L ∂ z 2 ⋅ ∂ z 2 ∂ a 1 \frac{\partial L}{\partial a_1} = \frac{\partial L}{\partial z_2} \cdot \frac{\partial z_2}{\partial a_1} a1L=z2La1z2

其中:

  • ∂ L ∂ z 2 \frac{\partial L}{\partial z_2} z2L 是损失函数对输出层加权输入 z 2 z_2 z2 的导数,表示的是输出层的误差。
  • ∂ z 2 ∂ a 1 \frac{\partial z_2}{\partial a_1} a1z2 是输出层加权输入 z 2 z_2 z2 对隐藏层激活值 a 1 a_1 a1 的导数。
2. 计算 ∂ z 2 ∂ a 1 \frac{\partial z_2}{\partial a_1} a1z2

输出层的加权输入 z 2 z_2 z2 是通过隐藏层激活值 a 1 a_1 a1 和输出层权重 w 2 w_2 w2 线性组合得到的,即:
z 2 = w 2 ⋅ a 1 + b 2 z_2 = w_2 \cdot a_1 + b_2 z2=w2a1+b2

因此, z 2 z_2 z2 a 1 a_1 a1 的导数为:
∂ z 2 ∂ a 1 = w 2 \frac{\partial z_2}{\partial a_1} = w_2 a1z2=w2

这意味着,输出层的加权输入 z 2 z_2 z2 对隐藏层的激活值 a 1 a_1 a1 的导数等于输出层的权重 w 2 w_2 w2

3. 计算 ∂ L ∂ z 2 \frac{\partial L}{\partial z_2} z2L

根据公式 8-13,损失函数 L L L 对输出层加权输入 z 2 z_2 z2 的导数为:
∂ L ∂ z 2 = a 2 − y \frac{\partial L}{\partial z_2} = a_2 - y z2L=a2y

这个导数表示的是输出层的误差,即模型的预测值 a 2 a_2 a2 和真实标签 y y y 之间的差异。

4. 结合链式法则

现在我们可以将这些结果代入链式法则,得到:
∂ L ∂ a 1 = ( a 2 − y ) ⋅ w 2 \frac{\partial L}{\partial a_1} = (a_2 - y) \cdot w_2 a1L=(a2y)w2

这个结果表示,损失函数对隐藏层激活值 a 1 a_1 a1 的梯度等于输出层误差 a 2 − y a_2 - y a2y 和输出层权重 w 2 w_2 w2 的乘积。

4. 公式的解释

公式 8-16 表明:

  • 损失函数对隐藏层激活值的梯度是通过输出层的误差输出层权重反向传播回来的。
  • 输出层的误差 a 2 − y a_2 - y a2y 表示模型的预测结果与真实标签之间的差异。
  • 输出层的权重 w 2 w_2 w2 用于将这个误差传播回隐藏层,以调整隐藏层的输出。

5. 直观理解

反向传播中,输出层的误差 a 2 − y a_2 - y a2y 是最直接的反馈,它表示模型预测的结果和真实标签之间的偏差。这个误差通过输出层的权重 w 2 w_2 w2 反向传播到隐藏层,影响隐藏层的激活值 a 1 a_1 a1 的更新。

  • 如果输出层的误差较大,即 a 2 a_2 a2 y y y 的差异较大,那么损失函数对隐藏层激活值的梯度也会较大。这意味着隐藏层需要做出较大的调整来修正这个误差。
  • 如果输出层的误差较小,即 a 2 a_2 a2 y y y 非常接近,说明模型的预测较准确,隐藏层的梯度就会较小,表示隐藏层的输出已经接近于正确的值,不需要大幅度调整。

通过公式 8-16,我们将输出层的误差传播回隐藏层,使得隐藏层能够感知到模型的整体误差,并相应地调整自身的输出。

6. 反向传播中的作用

公式 8-16 是反向传播算法中的关键一步。反向传播的基本思想是将输出层的误差逐层传递回去,最终传递到每一层的权重和偏置。这一步中的梯度 ∂ L ∂ a 1 \frac{\partial L}{\partial a_1} a1L 用于进一步计算隐藏层的权重 w 1 w_1 w1 和偏置 b 1 b_1 b1 的梯度,从而调整隐藏层的参数。

  • 反向传播的下一步,我们将根据这个梯度 ∂ L ∂ a 1 \frac{\partial L}{\partial a_1} a1L 来计算隐藏层权重和偏置的更新。
  • 隐藏层的权重和偏置调整之后,模型的整体误差会减小,模型的预测精度逐渐提高。

7. 总结

公式 8-16 通过链式法则,计算了损失函数对隐藏层激活值 a 1 a_1 a1 的梯度。这个梯度由输出层的误差 a 2 − y a_2 - y a2y输出层权重 w 2 w_2 w2 决定。通过这个公式,反向传播算法将误差从输出层传播到隐藏层,进而用于更新隐藏层的参数,最终使模型的整体误差减小,提高模型的预测性能。


http://www.ppmy.cn/news/1540880.html

相关文章

基于MinIO配置bucket,用于文件下载和浏览

文章目录 引言I 配置文件浏览安装MinIO配置自启动服务访问权限配置文件浏览访问地址文件下载地址II 知识扩展MinIO内置访问策略只读策略只写策略读写策略diagnosticsconsoleAdmin引言 需求:文件下载用于OTA升级,文件浏览用于产品展示。 实现方案:基于MinIO配置bucket访问权…

Android Automotive 获得谷歌地图事故报告功能

Android Automotive 迎来了谷歌地图的实时事故报告功能,这一更新标志着它与 Android Auto 的功能差距进一步缩小。 Android Auto 主要是通过手机与汽车的连接来提供服务,而Android Automotive 则是为汽车量身定制的系统——这在软件更新和用户体验上带来…

[Unity Demo]从零开始制作空洞骑士Hollow Knight第十四集:制作新的场景以及制作创建切换管理系统

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、制作新的场景 1.重新翻新各种Sprite2.制作地图前期应该做的事情3.疯狂的制作地图二、制作场景切换管理系统 1.制作场景切换点TransitionPoint2.切换场景时的…

Android Framework AMS(08)service组件分析-2(startService和StopService关键流程分析)

该系列文章总纲链接:专题总纲目录 Android Framework 总纲 本章关键点总结 & 说明: 说明:上一章节主要解读应用层service组件启动的2种方式startService和bindService,以及从APP层到AMS调用之间的打通。本章节主要关注service…

深入解析缓存与数据库数据不一致问题

缓存层是提高系统响应速度和扩展性的关键组件。然而,缓存层的引入也带来了数据一致性的挑战。 当数据库中的数据发生变化时,如何确保这些变化能够及时且准确地反映到缓存中,是确保用户体验和系统可靠性的重要问题。 1. 数据一致性 首先&am…

go压缩的使用

基础:使用go创建一个zip func base(path string) {// 创建 zip 文件zipFile, err : os.Create("test.zip")if err ! nil {panic(err)}defer zipFile.Close()// 创建一个新的 *Writer 对象zipWriter : zip.NewWriter(zipFile)defer zipWriter.Close()// 创…

原理代码解读:基于DiT结构视频生成模型的ControlNet

Diffusion Models视频生成-博客汇总 前言:相比于基于UNet结构的视频生成模型,DiT结构的模型最大的劣势在于生态不够完善,配套的ControlNet、IP-Adapter等开源权重不多,导致难以落地。最近DiT-based 5B的ControlNet开源了,相比于传统的ControlNet有不少改进点,这篇博客将从…

RabbitMQ 作为消息中间件,实现了支付消息的异步发送和接收, 同步和异步相比 响应速度具体比较

在支付场景中,使用 RabbitMQ 实现消息的异步发送和接收与同步处理相比,响应速度和整体系统性能会有显著的不同。以下是同步和异步方式在响应速度上的详细比较: 1. 同步处理方式 在同步模式下,支付消息的处理流程通常是&#xf…