神经网络反向传播交叉熵 计算损失函数对输出层偏置b2的梯度

server/2024/12/22 21:43:36/

本文是交叉熵损失函数为代表的两层神经网络反向传播量化求导计算公式中的一个公式,单独拿出来做一下解释说明。


公式 8-15反向传播算法中,计算损失函数对输出层偏置 b 2 b_2 b2 的梯度。这个梯度用于指导偏置的更新,从而最小化损失函数 L L L。我们将在此详细推导和解释公式 8-15。

公式 8-15 的表达式:

∂ L ∂ b 2 = 1 m ∑ i = 1 m ( a 2 − y ) (8-15) \frac{\partial L}{\partial b_2} = \frac{1}{m} \sum_{i=1}^{m} (a_2 - y) \tag{8-15} b2L=m1i=1m(a2y)(8-15)

1. 背景概念

  • L L L:损失函数,通常在分类问题中使用交叉熵损失。
  • b 2 b_2 b2:输出层的偏置项,位于输出神经元的输入前加上。
  • a 2 a_2 a2:输出层的激活值,是通过激活函数(如 sigmoid)计算得到的,代表模型的预测概率。
  • y y y:真实标签,代表样本的真实类别(0 或 1)。
  • m m m:训练集中的样本数量。
  • i i i:第 i i i 个样本的索引。

2. 公式的含义

公式 8-15 计算的是损失函数 L L L 对输出层偏置 b 2 b_2 b2 的导数,也就是通过反向传播算法,损失函数如何影响输出层的偏置项。这个梯度用于反向传播过程中的梯度更新。

偏置 b 2 b_2 b2 是输出层的一个独立参数,它不是与输入数据直接相连的,而是加在输出层的输入上。因此,它的梯度只与输出层的误差(即模型预测值与真实值之间的差异)有关。

3. 推导过程

为了推导公式 8-15,我们需要应用链式法则来计算损失函数对 b 2 b_2 b2 的导数。

1. 损失函数对偏置 b 2 b_2 b2 的导数

损失函数 L L L 对输出层偏置 b 2 b_2 b2 的导数可以通过链式法则写为:
∂ L ∂ b 2 = ∂ L ∂ z 2 ⋅ ∂ z 2 ∂ b 2 \frac{\partial L}{\partial b_2} = \frac{\partial L}{\partial z_2} \cdot \frac{\partial z_2}{\partial b_2} b2L=z2Lb2z2

其中:

  • ∂ L ∂ z 2 \frac{\partial L}{\partial z_2} z2L 是损失函数对输出层加权输入 z 2 z_2 z2 的导数,这个导数反映了模型输出层的误差。
  • ∂ z 2 ∂ b 2 \frac{\partial z_2}{\partial b_2} b2z2 是加权输入 z 2 z_2 z2 对偏置 b 2 b_2 b2 的导数。
2. 计算 ∂ z 2 ∂ b 2 \frac{\partial z_2}{\partial b_2} b2z2

输出层的加权输入 z 2 z_2 z2 的计算公式为:
z 2 = w 2 ⋅ a 1 + b 2 z_2 = w_2 \cdot a_1 + b_2 z2=w2a1+b2

因此, z 2 z_2 z2 对偏置 b 2 b_2 b2 的导数是:
∂ z 2 ∂ b 2 = 1 \frac{\partial z_2}{\partial b_2} = 1 b2z2=1

这个结果是因为偏置 b 2 b_2 b2 直接加到输出层的输入上,它对 z 2 z_2 z2 的变化直接影响,而与输入数据无关。

3. 计算 ∂ L ∂ z 2 \frac{\partial L}{\partial z_2} z2L

根据公式 8-13,我们已经知道:
∂ L ∂ z 2 = a 2 − y \frac{\partial L}{\partial z_2} = a_2 - y z2L=a2y

这个导数表示模型预测值 a 2 a_2 a2 和真实值 y y y 之间的差异。它衡量了输出层的误差大小。

4. 结合链式法则

现在我们可以将这些结果代入链式法则中,得到:
∂ L ∂ b 2 = ( a 2 − y ) ⋅ 1 = a 2 − y \frac{\partial L}{\partial b_2} = (a_2 - y) \cdot 1 = a_2 - y b2L=(a2y)1=a2y

这个表达式表示损失函数对输出层偏置 b 2 b_2 b2 的梯度等于模型输出层的误差 a 2 − y a_2 - y a2y

5. 对所有样本求和并取平均

在实际训练过程中,我们通常对多个样本进行计算,因此需要对所有样本的梯度求和并取平均。假设有 m m m 个样本,则最终的梯度表达式为:
∂ L ∂ b 2 = 1 m ∑ i = 1 m ( a 2 ( i ) − y ( i ) ) \frac{\partial L}{\partial b_2} = \frac{1}{m} \sum_{i=1}^{m} (a_2^{(i)} - y^{(i)}) b2L=m1i=1m(a2(i)y(i))

这里, i i i 表示第 i i i 个样本, a 2 ( i ) a_2^{(i)} a2(i) 是第 i i i 个样本的预测值, y ( i ) y^{(i)} y(i) 是第 i i i 个样本的真实标签。

4. 公式的解释

∂ L ∂ b 2 = 1 m ∑ i = 1 m ( a 2 − y ) \frac{\partial L}{\partial b_2} = \frac{1}{m} \sum_{i=1}^{m} (a_2 - y) b2L=m1i=1m(a2y)

这个公式表明,输出层偏置 b 2 b_2 b2 的梯度等于所有样本的输出误差 a 2 − y a_2 - y a2y 的平均值。偏置项的梯度只与输出层的误差有关,它不依赖于输入数据或隐藏层的激活值。

在梯度下降过程中,我们将使用这个梯度来更新输出层的偏置。通过调整偏置 b 2 b_2 b2,我们可以减少输出层的预测误差。

5. 直观理解

偏置 b 2 b_2 b2 是加在输出层的神经元上的一个常量,因此它的更新取决于输出层的误差。在每个样本的预测中,偏置 b 2 b_2 b2 的更新方向和大小都依赖于模型的预测值 a 2 a_2 a2 和真实值 y y y 之间的差异。

  • 如果预测 a 2 a_2 a2 高于真实值 y y y,即 a 2 > y a_2 > y a2>y,那么梯度 a 2 − y a_2 - y a2y 为正,表示需要减少偏置的值,从而减小预测。
  • 如果预测 a 2 a_2 a2 低于真实值 y y y,即 a 2 < y a_2 < y a2<y,那么梯度 a 2 − y a_2 - y a2y 为负,表示需要增加偏置的值,从而增大预测。

通过对多个样本的误差求平均,我们可以平滑梯度的更新,使偏置 b 2 b_2 b2 的调整更稳定。

6. 总结

公式 8-15 通过链式法则,计算出损失函数对输出层偏置 b 2 b_2 b2 的梯度。公式表达了偏置的梯度是所有样本的输出误差 ( a 2 − y ) (a_2 - y) (a2y) 的平均值。这个梯度用于反向传播过程中,通过梯度下降更新偏置,使得神经网络的损失函数逐步减小,最终提高模型的预测准确性。


http://www.ppmy.cn/server/132175.html

相关文章

尚硅谷rabbitmq2024 集群篇仲裁队列 第52节 答疑

我们希望创建一个队列&#xff0c;队列分布在各个节点上&#xff0c;仲裁队列很好的解决了这个问题.那么在仲裁队列之前&#xff0c;创建一个队列&#xff0c;队列不是分布在各个节点上的吗&#xff1f; 在RabbitMQ中&#xff0c;默认情况下创建的队列是“普通队列”&#xff0…

【Linux】ioctl分析

简介 一个字符设备驱动通常会实现常规的open、release、read和write接口&#xff0c;但是如果需要扩展新的功能&#xff0c;通常以ioctl接口的方式实现。 #mermaid-svg-uY8EyPklf5e4ZMQo {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill…

简单实现手机、电脑相互操作

1、从手机截图到sdcard 2、将图片导出到PC 3、从PC加载图片 4、开启定时器 5、操作电脑UI事件 1、 private static void takeScreenshot(String path) {long t1 System.currentTimeMillis();String command "adb devices"; // 替换为你需要执行的shell命令Str…

C#中判断的应用说明二(switch语句)

一.判断的定义说明 判断结构要求程序员指定一个或多个要评估或测试的条件&#xff0c;以及条件为真时要执行的语句&#xff08;必需的&#xff09;和条件为假时要执行的语句&#xff08;可选的&#xff09;。下面是大多数编程语言中典型的判断结构的一般形式&#xff1a; 二.判…

百度下拉框出词技术解密:72小时出下拉词软件原理分享

如何才能刷下拉词&#xff1f;这个问题一直是企业做流量时最纠结的问题&#xff0c;百度下拉词作为百度搜索体验中的一项智能化功能&#xff0c;极大地方便了用户快速完成搜索&#xff0c;也成为了企业在搜索引擎优化&#xff08;SEO&#xff09;策略中的重要流量入口。通过研究…

int QSqlQuery::size() const

返回结果的大小&#xff08;返回的行数&#xff09; 或者返回-1 &#xff08;如果大小不能被决定 或者 数据库不支持报告查询的大小信息&#xff09; 注意&#xff1a;对于非查询语句&#xff0c;将返回-1&#xff08;isSelect()返回false&#xff09; 如果查询不是活跃的&…

vue项目中使用websocket

一、单文件中引入使用 <template></template> <script>export default {websocket: true, // 标志需要使用WebSocketdata () {return {ws: null}},created () {this.ws new WebSocket(ws://127.0.0.1:8000); // ws服务地址this.ws.onopen () > {// 接收…

Django项目的创建及说明(详细图解版)

Django项目的创建及说明 1、安装Django2、创建项目2.1、利用终端创建项目2.2、利用Pycharm企业版创建项目 3、默认文件介绍 1、安装Django 在终端输入下述命令行。 pip install django安装成功后执行如下命令查看Django是否安装好&#xff0c;若正确显示出Django版本号则安装…