跟着问题学12——GRU详解

embedded/2024/9/24 7:32:25/

   

1 GRU

1. 什么是GRU

GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆

和反向传播中的梯度等问题而提出来的。

GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢。

下图1-1引用论文中的一段话来说明GRU的优势所在。

图1-1 R-NET: MACHINE READING COMPREHENSION WITH SELF-MATCHING NETWORKS(2017)

简单译文:我们在我们的实验中选择GRU是因为它的实验效果与LSTM相似,但是更易于计算。

相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。

OK,那么为什么说GRU更容易进行训练呢,下面开始介绍一下GRU的内部结构。

RNN的缺陷——长期依赖的问题 (The Problem of Long-Term Dependencies)

RNNs的一个吸引人的地方是,他们可能能够将以前的信息与现在的任务联系起来,例如使用视频前面的几帧画面可能有助于理解现在这一帧的画面。如果RNNs能做到这一点,它们将非常有用。但它们能不能有效,这得视情况而定。

但也有一些情况,我们需要更多的上下文。试着预测课文中的最后一个单词“我在法国长大……我说一口流利的法语。”“最近的信息显示,下一个单词很可能是一种语言的名字,但如果我们想缩小范围,我们需要更早的法语语境。”相关信息与需要它的点之间的差距完全有可能变得非常大。

不幸的是,随着这种差距的扩大,RNNs无法学会连接信息。

从理论上讲,RNN绝对有能力处理这种“长期依赖性”。人们可以为他们精心选择参数,以解决这种形式的问题。遗憾的是,在实践中,RNN似乎无法学习它们。 Hochreiter (1991) [German]和 Bengio, et al. (1994)等人对此问题进行了深入探讨。 他们发现了一些RNN很难做到的根本原因。【http://ai.dinfo.unifi.it/paolo//ps/tnn-94-gradient.pdf

http://people.idsia.ch/~juergen/SeppHochreiter1991ThesisAdvisorSchmidhuber.pdf】

幸运的是,LSTM没有这个问题!

总体结构框架

多层感知机(线性连接层)

从特征角度考虑:输入特征是n*1的单维向量(这也是为什么linear层前要把所有特征层展平的原因),然后根据隐含层神经元的数量m将前层输入的特征用m*1的单维向量进行表示(对特征进行了提取变换),单个隐含层的神经元数量就代表网络参数,可以设置多个隐含层;最终根据输出层的神经元数量y输出y*1的单维向量。

卷积神经网络

 从特征角度考虑:输入特征是width*height*channel的张量, 然后根据通道channel的数量c会有c个卷积核将前层输入的特征用k*k*c的张量进行卷积(对特征进行了提取变换,k为卷积核尺寸),卷积核的大小和数量k*k*c就代表网络参数,可以设置多个隐含层;每一个channel都代表提取某方面的一种特征,该特征用width*height的二维张量表示,不同特征层之间是相互独立的(可以进行融合)。最终根据场景的需要设置后面的输出。

RNN&LSTM&GRU

从特征角度考虑:输入特征是T_seq*feature_size的张量(T_seq代表序列长度),每个时刻t可以类似于CNN的通道channel,只是时刻t的特征(channel)是和t之前时刻的特征(channel)相关联的,所以H_t是由X_t和H_t-1共同作为输入决定的,每个时刻t的特征表示是用feature_size*1的单维向量表示的,每个隐状态H_t类似于一个channel,特征的表示是用hidden_size*1的单维向量表示的,H_t的channel总数就是输入的序列长度,所以一个隐含层是T_seq*hidden_size的张量,如图中所示,同一个隐含层不同时刻的参数W_ih和W_hh是共享的;隐含层可以有num_layers个(图中只有1个)

以t时刻具体阐述一下:

X_t是t时刻的输入,是一个feature_size*1的向量

W_ih是输入层到隐藏层的权重矩阵

H_t是t时刻的隐藏层的值,是一个hidden_size*1的向量

W_hh是上一时刻的隐藏层的值传入到下一时刻的隐藏层时的权重矩阵

Ot是t时刻RNN网络的输出

从上右图中可以看出这个RNN网络在t时刻接受了输入Xt之后,隐藏层的值是St,输出的值是Ot。但是从结构图中我们可以发现St并不单单只是由Xt决定,还与t-1时刻的隐藏层的值St-1有关。

2.1 GRU的输入输出结构

GRU的输入输出结构与普通的RNN是一样的。有一个当前的输入xt,和上一个节点传递下来的隐状态(hidden state)ht-1 ,这个隐状态包含了之前节点的相关信息。结合xt和 ht-1,GRU会得到当前隐藏节点的输出yt 和传递给下一个节点的隐状态 ht。

图2-1 GRU的输入输出结构

那么,GRU到底有什么特别之处呢?下面来对它的内部结构进行分析!

2.2 GRU的内部结构

首先,我们先通过上一个传输下来的状态

和当前节点的输入 来获取两个门控状态。如下图2-2所示,其中 控制重置的门控(reset gate),

为控制更新的门控(update gate)。

Tips:

为sigmoid函数,通过这个函数可以将数据变换为0-1范围内的数值,从而来充当门控信号。

得到门控信号之后,首先使用重置门控来得到“重置”之后的数据 ,再将 与输入 进行拼接,再通过一个tanh激活函数来将数据放缩到-1~1的范围内。即得到如下图2-3所示的 。

这里的 主要是包含了当前输入的 数据。有针对性地对 添加到当前的隐藏状态,相当于”记忆了当前时刻的状态“。

图2-4中的 是Hadamard Product,也就是操作矩阵中对应的元素相乘,因此要求两个相乘矩阵是同型的。 则代表进行矩阵加法操作。

最后介绍GRU最关键的一个步骤,我们可以称之为”更新记忆“阶段。

在这个阶段,我们同时进行了遗忘了记忆两个步骤。我们使用了先前得到的更新门控 (update gate)。

首先再次强调一下,门控信号(这里的 )的范围为0~1。门控信号越接近1,代表”记忆“下来的数据越多;而越接近0则代表”遗忘“的越多。

概括来说,GRU将遗忘和输入门组合成一个“更新门”。“它还融合了细胞状态和隐藏状态,并做出一些其他的改变。得到的模型比标准LSTM模型更简单,并且越来越受欢迎。

参考资料

https://zhuanlan.zhihu.com/p/32481747

https://speech.ee.ntu.edu.tw/~tlkagk/courses/MLDS_2018/Lecture/Seq%20(v2).pdf

https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=333.788&vd_source=cf7630d31a6ad93edecfb6c5d361c659


http://www.ppmy.cn/embedded/115977.html

相关文章

掌上高考爬虫逆向分析

目标网站 aHR0cHM6Ly93d3cuZ2Fva2FvLmNuL3NjaG9vbC9zZWFyY2g/cmVjb21zY2hwcm9wPSVFNSU4QyVCQiVFOCU4RCVBRg 一、抓包分析 二、逆向分析 搜索定位加密参数 本地生成代码 var CryptoJS require(crypto-js) var crypto require(crypto);f "D23ABC#56"function v(t…

SpringBoot3核心特性-核心原理

目录 传送门前言一、事件和监听器1、生命周期监听2、事件触发时机 二、自动配置原理1、入门理解1.1、自动配置流程1.2、SPI机制1.3、功能开关 2、进阶理解2.1、 SpringBootApplication2.2、 完整启动加载流程 三、自定义starter1、业务代码2、基本抽取3、使用EnableXxx机制4、完…

Java | Leetcode Java题解之第433题最小基因变化

题目&#xff1a; 题解&#xff1a; class Solution {public int minMutation(String start, String end, String[] bank) {int m start.length();int n bank.length;List<Integer>[] adj new List[n];for (int i 0; i < n; i) {adj[i] new ArrayList<Intege…

Flutter鸿蒙化环境配置(windows)

Flutter鸿蒙化环境配置&#xff08;windows&#xff09; 参考资料Window配置Flutter的鸿蒙化环境下载配置环境变量HarmonyOS的环境变量配置配置Flutter的环境变量Flutter doctor -v 检测的问题flutter_flutter仓库地址的警告问题Fliutter doctor –v 报错[!] Android Studio (v…

cccccccccccc

目录 1. ls指令 2. pwd指令 3. cd指令 4. whoami 5. clear指令 6. touch指令 7. mkdir指令(重要) 8. rmdir指令与rm指令(重要) 8.1 rmdir指令 8.2 rm指令 9. man指令(重要) 10. cp指令(重要) 11. mv指令(重要) 12. nano指令 13. cat指令 14. echo指令 重定向 1…

携手阿里云CEN:共创SD-WAN融合广域网

在9月19日举行的阿里云云栖大会上&#xff0c;犀思云作为SD-WAN领域的杰出代表及阿里云的SD-WAN重要合作伙伴&#xff0c;携手阿里云共同推出了创新的企业上云方案——Fusion WAN智连阿里云解决方案。这一创新方案不仅彰显了犀思云在SD-WAN技术领域的深厚积累&#xff0c;更体现…

秒表【JavaScript】

这个代码实现了一个基本的功能性秒表。 实现功能&#xff1a; 代码&#xff1a; <!DOCTYPE html> <html lang"zh"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-sc…

git学习报告

文章目录 git学习报告如何配置vscode终端安装PowerShell安装 Microsoft.Powershell.Preview使用 git的使用关于团队合作 git指令本地命令&#xff1a;云端指令 git学习报告 如何配置vscode 安装powershell调教window终端&#xff0c;使其像Linux一样&#xff0c;通过Linux命令…