GRU 和 LSTM 公式推导与矩阵变换过程图解

ops/2025/2/12 20:10:55/

GRU 和 LSTM 公式推导与矩阵变换过程图解

  • GRU
  • LSTM


本文的前置篇链接: 单向/双向,单层/多层RNN输入输出维度问题一次性解决


GRU

GRU(Gate Recurrent Unit)是循环神经网络(RNN)的一种,可以解决RNN中不能长期记忆和反向传播中的梯度等问题,与LSTM的作用类似,不过比LSTM简单,容易进行训练。

GRU的输入输出结构与普通的RNN是一样的:

在这里插入图片描述
GRU 内部结构有两个核心的门控状态,分别为r控制重置的门控(reset gate) , z控制更新的门控(update gate) ,这两个门控状态值通过上一时刻传递过来的 h t − 1 h_{t-1} ht1和当前节点输入 x t x_{t} xt结合sigmoid得出,具体公式如下:

在这里插入图片描述

下面我们更加细致来看一下两个门控状态值计算过程中的矩阵维度变换:

在这里插入图片描述

在这里插入图片描述

得到门控信号后,我们首先使用重置门r来得到重置后的 h t − 1 ′ = h t − 1 ⊙ r {h_{t-1}}' =h_{t-1}\odot r ht1=ht1r :

在这里插入图片描述

符号 ⊙ 通常表示 Hadamard 乘积(也称为元素-wise 乘积或 Schur 乘积)。Hadamard 乘积是两个矩阵的元素-wise 乘积,即两个矩阵的对应元素相乘,图中用*代替Hadamard 乘积,用x代替矩阵乘法运算

再将 h t − 1 ′ {h_{t-1}}' ht1与输入 x t x_{t} xt进行拼接,再通过一个tanh激活函数对数据进行非线性变换,得到 h ′ {h}' h, 具体计算过程如下图所示
在这里插入图片描述

更新记忆阶段,我们同时进行了遗忘记忆两个步骤。我们使用了先前得到的更新门控z,更新表达式如下:
h t = z ⊙ h t − 1 + ( 1 − z ) ⊙ h ′ h_{t} = z\odot h_{t-1} + (1-z) \odot {h}' ht=zht1+(1z)h
在这里插入图片描述

门控信号z的范围为0~1。门控信号越接近1,代表”记忆“下来的数据越多;而越接近0则代表”遗忘“的越多。

GRU很聪明的一点就在于,我们使用了同一个门控z就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。

  • z ⊙ h t − 1 z\odot h_{t-1} zht1 表示对原本隐藏状态的选择性遗忘,这里的 z z z可以想象成遗忘门,忘记 h t − 1 h_{t-1} ht1中一些不重要的信息。
  • ( 1 − z ) ⊙ h ′ (1-z)\odot {h}' (1z)h 表示对包含当前节点信息的 h ′ {h}' h进行选择性记忆,这里的 ( 1 − z ) (1-z) (1z)同理会忘记 h ′ {h}' h中一些不重要的信息,或者看成是对 h ′ {h}' h中某些重要信息的筛选。
  • h t = z ⊙ h t − 1 + ( 1 − z ) ⊙ h ′ h_{t} = z\odot h_{t-1} + (1-z) \odot {h}' ht=zht1+(1z)h 总的来看就是忘记传递下来的 h t − 1 h_{t-1} ht1中的某些不重要信息,并加入当前节点输入信息中某些重要部分。

更新门 与 重置门 总结:

  • 重置门的作用是决定前一时间步的隐藏状态在多大程度上被忽略。当重置门的输出接近0时,网络倾向于“忘记”前一时间步的信息,仅依赖于当前输入;而当输出接近1时,前一时间步的信息将被更多地保留。
  • 更新门的作用是决定当前时间步的隐藏状态需要保留多少前一个时间步的信息。更新门的输出值介于0和1之间,值越大表示保留的过去信息越多,值越小则意味着更多地依赖于当前输入的信息。

把GRU所有流程放在一张图展示:
在这里插入图片描述

总的来说:

  • GRU输入输出的结构与普通的RNN相似,其中的内部思想与LSTM相似。
  • 与LSTM相比,GRU内部少了一个”门控“,参数比LSTM少,但是却也能够达到与LSTM相当的功能。考虑到硬件的计算能力和时间成本,因而很多时候我们也就会选择更加”实用“的GRU啦。

补充 ( 看完 L S T M 部分后,再看该补充说明 ) : 补充(看完LSTM部分后,再看该补充说明): 补充(看完LSTM部分后,再看该补充说明):

  • r (reset gate)实际上与他的名字有些不符合,因为我们仅仅使用它来获得了 h ′ {h}' h
  • GRU中的 h ′ {h}' h实际上可以看成对应于LSTM中的hidden state,而上一个节点传下的 h t − 1 h_{t-1} ht1则对应于LSTM中的cell state。
  • z z z对应的则是LSTM中的 z f z_{f} zf,那么 ( 1 − z ) (1-z) (1z)就可以看成是选择门 z i z_{i} zi

LSTM

长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

LSTM 的输入输出结构与普通的RNN区别如下图所示:

在这里插入图片描述
相比于RNN只有一个传递状态 h t h_{t} ht,LSTM有两个传输状态,一个 c t c_{t} ct(cell state) ,和一个 h t h_{t} ht(hidden state)。

LSTM中的 c t c_{t} ct相当于RNN中的 h t h_{t} ht

下面我们来看一下LSTM内部结构,首先使用LSTM的当前输入 x t x_{t} xt和上一个状态传递下来的 h t − 1 h_{t-1} ht1计算得到四个状态值:

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
z i , z f , z o z_{i},z_{f},z{o} zi,zf,zo 通过sigmoid非线性变换,分别作为输入门控,遗忘门控和输出门控,而 z z z则是通过tanh非线性变换,作为输入数据。

LSTM内部主要有三个阶段:

  1. 选择记忆阶段: 该阶段会对输入的数据 z z z进行有选择性的记忆,而选择的门控信号由 z i z_{i} zi来进行控制。
    在这里插入图片描述

  2. 忘记阶段: 这个阶段是对上一个节点传入的输入进行选择性的忘记,通过 z f z_{f} zf作为忘记门控,来控制上一个状态 c t − 1 c_{t-1} ct1哪些需要留下,哪些需要忘记。

在这里插入图片描述

  1. 将上面两个阶段得到的结果相加,即可得到传输给下一个状态的 c t c_{t} ct

在这里插入图片描述

  1. 输出阶段: 这个阶段会决定哪些将会被当成当前状态的输出,主要是通过 z o z_{o} zo进行控制,并且还对上一个阶段得到的 c t c_{t} ct进行了tanh的非线性变换。

在这里插入图片描述

  1. 与普通RNN类似,输出的 y t y_{t} yt往往最终也是通过 h t h_{t} ht变换得到,如: y t = s i g m o i d ( W h t ) y_{t} = sigmoid(W h_{t}) yt=sigmoid(Wht)

把LSTM所有流程放在一张图展示:
在这里插入图片描述

总结:

  • 以上,就是LSTM的内部结构。通过门控状态来控制传输状态,记住需要长时间记忆的,忘记不重要的信息;而不像普通的RNN那样只能够“呆萌”地仅有一种记忆叠加方式。对很多需要“长期记忆”的任务来说,尤其好用。

  • 但也因为引入了很多内容,导致参数变多,也使得训练难度加大了很多。因此很多时候我们往往会使用效果和LSTM相当但参数更少的GRU来构建大训练量的模型。


http://www.ppmy.cn/ops/157845.html

相关文章

使用Docker + Ollama在Ubuntu中部署deepseek

1、安装docker 这里建议用docker来部署,方便简单 安装教程需要自己找详细的,会用到跳过 如果你没有安装 Docker,可以按照以下步骤安装: sudo apt update sudo apt install apt-transport-https ca-certificates curl software-p…

【大模型】AI 辅助编程操作实战使用详解

目录 一、前言 二、AI 编程介绍 2.1 AI 编程是什么 2.1.1 为什么需要AI辅助编程 2.2 AI 编程主要特点 2.3 AI编程底层核心技术 2.4 AI 编程核心应用场景 三、AI 代码辅助编程解决方案 3.1 AI 大模型平台 3.1.1 AI大模型平台代码生成优缺点 3.2 AI 编码插件 3.3 AI 编…

怎样确定网站访问速度出现问题是后台还是服务器造成的?

网站的访问速度会影响到用户的体验感,当网络过于卡顿或访问速度较慢时,会给用户带来不好的体验感,但是网站访问速度不仅会是后台造成影响的,也可能是服务器的原因,那么我们该如何分辨呢? 当网站使用了数据库…

单片机上SPI和IIC的区别

SPI(Serial Peripheral Interface)和IC(Inter-Integrated Circuit)是两种常用的嵌入式外设通信协议,它们各有优缺点,适用于不同的场景。以下是它们的详细对比: — 1. 基本概念 SPI&#xff0…

【DeepSeek】deepseek可视化部署

目录 1 -> 前文 2 -> 部署可视化界面 1 -> 前文 【DeepSeek】DeepSeek概述 | 本地部署deepseek 通过前文可以将deepseek部署到本地使用,可是每次都需要winR输入cmd调出命令行进入到命令模式,输入命令ollama run deepseek-r1:latest。体验很…

C++ ——从C到C++

1、C的学习方法 (1)C知识点概念内容比较多,需要反复复习 (2)偏理论,有的内容不理解,可以先背下来,后续可能会理解更深 (3)学好编程要多练习,简…

Spring常用注解和组件

引言 了解Spring常用注解的使用方式可以帮助我们更快速理解这个框架和其中的深度 注解 Configuration:表示该类是一个配置类,用于定义 Spring Bean。 EnableAutoConfiguration:启用 Spring Boot 的自动配置功能,让 Spring Boo…

Git 分布式版本控制工具使用教程

1.关于Git 1.1 什么是Git Git是一款免费、开源的分布式版本控制工具,由Linux创始人Linus Torvalds于2005年开发。它被设计用来处理从很小到非常大的项目,速度和效率都非常高。Git允许多个开发者几乎同时处理同一个项目而不会互相干扰,并且在…