神经网络反向传播算法公式推导

server/2024/11/24 15:28:36/

要推导反向传播算法,并了解每一层的参数梯度如何计算,以及每一层的梯度受到哪些值的影响,我们使用一个简单的神经网络结构:

  • 输入层有2个节点
  • 一个有2个节点的隐藏层,激活函数是ReLU
  • 一个输出节点,激活函数是线性激活(即没有激活函数)

假设权重矩阵和偏置如下:

  • 输入层到隐藏层的权重矩阵 W 1 W_1 W1 2 × 2 2 \times 2 2×2
  • 隐藏层的偏置向量 b 1 b_1 b1 2 × 1 2 \times 1 2×1
  • 隐藏层到输出层的权重矩阵 W 2 W_2 W2 2 × 1 2 \times 1 2×1
  • 输出层的偏置向量 b 2 b_2 b2是一个标量

输入为 x = [ x 1 , x 2 ] x = [x_1, x_2] x=[x1,x2],期望输出为 y y y,损失函数为均方误差(MSE)。

前向传播:

  1. 计算隐藏层的输入:
    z 1 = W 1 ⋅ x + b 1 z_1 = W_1 \cdot x + b_1 z1=W1x+b1
  2. 计算隐藏层的激活:
    a 1 = ReLU ( z 1 ) a_1 = \text{ReLU}(z_1) a1=ReLU(z1)
  3. 计算输出层的输入:
    z 2 = W 2 T ⋅ a 1 + b 2 z_2 = W_2^T \cdot a_1 + b_2 z2=W2Ta1+b2
  4. 输出值:
    y ^ = z 2 \hat{y} = z_2 y^=z2
  5. 计算损失:
    L = 1 2 ( y ^ − y ) 2 L = \frac{1}{2} (\hat{y} - y)^2 L=21(y^y)2

反向传播:

  1. 计算输出层的梯度:

    • 损失函数对输出层输入的梯度:
      ∂ L ∂ z 2 = y ^ − y \frac{\partial L}{\partial z_2} = \hat{y} - y z2L=y^y
  2. 计算从输出层到隐藏层的梯度:

    • 隐藏层激活对权重的梯度:
      ∂ L ∂ W 2 = ∂ L ∂ z 2 ⋅ a 1 \frac{\partial L}{\partial W_2} = \frac{\partial L}{\partial z_2} \cdot a_1 W2L=z2La1
    • 隐藏层激活对偏置的梯度:
      ∂ L ∂ b 2 = ∂ L ∂ z 2 \frac{\partial L}{\partial b_2} = \frac{\partial L}{\partial z_2} b2L=z2L
  3. 计算隐藏层的梯度:

    • 损失函数对隐藏层激活的梯度:
      ∂ L ∂ a 1 = W 2 ⋅ ∂ L ∂ z 2 \frac{\partial L}{\partial a_1} = W_2 \cdot \frac{\partial L}{\partial z_2} a1L=W2z2L
    • 隐藏层对隐藏层输入的梯度(ReLU的梯度):
      ∂ L ∂ z 1 = ∂ L ∂ a 1 ⋅ ReLU ′ ( z 1 ) \frac{\partial L}{\partial z_1} = \frac{\partial L}{\partial a_1} \cdot \text{ReLU}'(z_1) z1L=a1LReLU(z1)
      • ReLU梯度 ReLU ′ ( z 1 ) \text{ReLU}'(z_1) ReLU(z1) z 1 > 0 z_1 > 0 z1>0时为1,否则为0
  4. 计算从输入层到隐藏层的梯度:

    • 输入对权重的梯度:
      ∂ L ∂ W 1 = ∂ L ∂ z 1 ⋅ x T \frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial z_1} \cdot x^T W1L=z1LxT
    • 输入对偏置的梯度:
      ∂ L ∂ b 1 = ∂ L ∂ z 1 \frac{\partial L}{\partial b_1} = \frac{\partial L}{\partial z_1} b1L=z1L

详细推导实例:

假设:

  • x = [ 1 , 2 ] x = [1, 2] x=[1,2]
  • y = 3 y = 3 y=3
  • W 1 = [ 0.5 0.2 0.3 0.7 ] W_1 = \begin{bmatrix} 0.5 & 0.2 \\ 0.3 & 0.7 \end{bmatrix} W1=[0.50.30.20.7]
  • b 1 = [ 0.1 0.2 ] b_1 = \begin{bmatrix} 0.1 \\ 0.2 \end{bmatrix} b1=[0.10.2]
  • W 2 = [ 0.6 0.9 ] W_2 = \begin{bmatrix} 0.6 \\ 0.9 \end{bmatrix} W2=[0.60.9]
  • b 2 = 0.3 b_2 = 0.3 b2=0.3

前向传播:
1.
z 1 = W 1 ⋅ x + b 1 = [ 0.5 0.2 0.3 0.7 ] ⋅ [ 1 2 ] + [ 0.1 0.2 ] = [ 1.0 1.9 ] z_1 = W_1 \cdot x + b_1 = \begin{bmatrix} 0.5 & 0.2 \\ 0.3 & 0.7 \end{bmatrix} \cdot \begin{bmatrix} 1 \\ 2 \end{bmatrix} + \begin{bmatrix} 0.1 \\ 0.2 \end{bmatrix} = \begin{bmatrix} 1.0 \\ 1.9 \end{bmatrix} z1=W1x+b1=[0.50.30.20.7][12]+[0.10.2]=[1.01.9]
2.
a 1 = ReLU ( z 1 ) = ReLU ( [ 1.0 1.9 ] ) = [ 1.0 1.9 ] a_1 = \text{ReLU}(z_1) = \text{ReLU}(\begin{bmatrix} 1.0 \\ 1.9 \end{bmatrix}) = \begin{bmatrix} 1.0 \\ 1.9 \end{bmatrix} a1=ReLU(z1)=ReLU([1.01.9])=[1.01.9]
3.
z 2 = W 2 T ⋅ a 1 + b 2 = [ 0.6 0.9 ] T ⋅ [ 1.0 1.9 ] + 0.3 = 2.46 z_2 = W_2^T \cdot a_1 + b_2 = \begin{bmatrix} 0.6 \\ 0.9 \end{bmatrix}^T \cdot \begin{bmatrix} 1.0 \\ 1.9 \end{bmatrix} + 0.3 = 2.46 z2=W2Ta1+b2=[0.60.9]T[1.01.9]+0.3=2.46
4.
y ^ = z 2 = 2.46 \hat{y} = z_2 = 2.46 y^=z2=2.46
5.
L = 1 2 ( 2.46 − 3 ) 2 = 0.1458 L = \frac{1}{2} (2.46 - 3)^2 = 0.1458 L=21(2.463)2=0.1458

反向传播:
1.
∂ L ∂ z 2 = 2.46 − 3 = − 0.54 \frac{\partial L}{\partial z_2} = 2.46 - 3 = -0.54 z2L=2.463=0.54

  1. ∂ L ∂ W 2 = [ − 0.54 ] ⋅ [ 1.0 1.9 ] = [ − 0.54 ⋅ 1.0 − 0.54 ⋅ 1.9 ] = [ − 0.54 − 1.026 ] \frac{\partial L}{\partial W_2} = \begin{bmatrix} -0.54 \end{bmatrix} \cdot \begin{bmatrix} 1.0 \\ 1.9 \end{bmatrix} = \begin{bmatrix} -0.54 \cdot 1.0 \\ -0.54 \cdot 1.9 \end{bmatrix} = \begin{bmatrix} -0.54 \\ -1.026 \end{bmatrix} W2L=[0.54][1.01.9]=[0.541.00.541.9]=[0.541.026]
    ∂ L ∂ b 2 = − 0.54 \frac{\partial L}{\partial b_2} = -0.54 b2L=0.54

  2. ∂ L ∂ a 1 = [ 0.6 0.9 ] ⋅ − 0.54 = [ − 0.324 − 0.486 ] \frac{\partial L}{\partial a_1} = \begin{bmatrix} 0.6 \\ 0.9 \end{bmatrix} \cdot -0.54 = \begin{bmatrix} -0.324 \\ -0.486 \end{bmatrix} a1L=[0.60.9]0.54=[0.3240.486]
    ∂ L ∂ z 1 = ∂ L ∂ a 1 ⋅ ReLU ′ ( z 1 ) = [ − 0.324 − 0.486 ] ⋅ [ 1 1 ] = [ − 0.324 − 0.486 ] \frac{\partial L}{\partial z_1} = \frac{\partial L}{\partial a_1} \cdot \text{ReLU}'(z_1) = \begin{bmatrix} -0.324 \\ -0.486 \end{bmatrix} \cdot \begin{bmatrix} 1 \\ 1 \end{bmatrix} = \begin{bmatrix} -0.324 \\ -0.486 \end{bmatrix} z1L=a1LReLU(z1)=[0.3240.486][11]=[0.3240.486]

  3. ∂ L ∂ W 1 = ∂ L ∂ z 1 ⋅ x T = [ − 0.324 − 0.486 ] ⋅ [ 1 2 ] T = [ − 0.324 − 0.648 − 0.486 − 0.972 ] \frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial z_1} \cdot x^T = \begin{bmatrix} -0.324 \\ -0.486 \end{bmatrix} \cdot \begin{bmatrix} 1 & 2 \end{bmatrix}^T = \begin{bmatrix} -0.324 & -0.648 \\ -0.486 & -0.972 \end{bmatrix} W1L=z1LxT=[0.3240.486][12]T=[0.3240.4860.6480.972]
    ∂ L ∂ b 1 = [ − 0.324 − 0.486 ] \frac{\partial L}{\partial b_1} = \begin{bmatrix} -0.324 \\ -0.486 \end{bmatrix} b1L=[0.3240.486]

从上述示例可以看到,每层的梯度依赖于上一层的激活值和当前层的损失梯度。梯度的传递通过链式法则一步步向前传播,从最初的损失函数计算开始,直到最终的输入层的权重和偏置。


http://www.ppmy.cn/server/144572.html

相关文章

打开智能识别API接口时代

近年来,随着人工智能的迅猛发展,智能识别技术逐渐成为了各行各业的热门话题。无论是在金融领域的身份证信息识别,还是在电商领域的收货信息识别,智能识别技术的应用都得到了广泛推广和应用。为了满足市场的需求,挖数据…

C++ 中的模板特化和偏特化

模版特化是为特定的模版参数类型提供提供专门的实现。 当需要为特定类型的模板参数提供不同的实现时,可以使用模板特化。模板特化允许你为特定的模板参数类型编写专门的代码,而不是使用通用的模板代码。 例如,对于一个通用的模板函数&#…

【设计模式】【创建型模式(Creational Patterns)】之单例模式

单例模式是一种常用的创建型设计模式,其目的是确保一个类只有一个实例,并提供一个全局访问点。 单例模式的原理 单例模式的核心在于控制类的实例化过程,通常通过以下方式实现: 私有化构造函数,防止外部直接实例化。…

以3D数字人AI产品赋能教育培训人才发展,魔珐科技亮相AI+教育创新与人才发展大会

11月20日,北京中关村国际创新中心迎来了“AI教育创新与人才发展大会暨首届北京数字人才发展大会”的盛大启幕。此次大会汇聚了培训、教育、科技、人才领域的专家学者、行业领袖及企业代表,共同探讨人工智能技术在教育培训领域的革新应用与数字人才培养体…

C语言练级->##__VA_ARGS__(可变参数)的用法

有什么用? 通常__VA_ARGS__用于宏定义,其中关于日志宏需要用的,printf 等支持可变参数的函数的宏封装。 首先我们先知道这个__VA_ARGS__的英文全称是“Variadic Arguments” 叫可变参数。说到可变参数学过C语言的朋友们应该都会想到printf&…

“LLM是否是泡沫”

目录 “LLM是否是泡沫” 培养自己鉴别论文价值的能力、复现开源项目的能力、debug 代码的能力 llm 是生产力工具 多去找实习,读再多的论文,刷再多的技术文章,也不如一次 debug 多机通讯报错带来的认知深刻 一、LLM领域的发展与挑战 二、…

【LeetCode每日一题】——746.使用最小花费爬楼梯

文章目录 一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【题目提示】七【解题思路】八【时空频度】九【代码实现】十【提交结果】 一【题目类别】 数组 二【题目难度】 简单 三【题目编号】 746.使用最小花费爬楼梯 四【题目描述】 给你一…

node读取execl或写入execl数据保存

nodejs 使用 exceljs 库读取 execl 或写入 execl 数据后保存文件 安装库 exceljs npm i exceljs 读取execl const exceljs require(exceljs)const workbook new exceljs.Workbook() await workbook.xlsx.readFile(test.xlsx) // 读取第一个工作表 const worksheet workbo…