56 门控循环单元(GRU)_by《李沐:动手学深度学习v2》pytorch版

news/2024/12/21 22:23:12/

系列文章目录


文章目录

  • 系列文章目录
  • 门控循环单元(GRU)
    • 门控隐状态
      • 重置门和更新门
      • 候选隐状态
      • 隐状态
    • 从零开始实现
      • 初始化模型参数
      • 定义模型
      • 训练与预测
    • 简洁实现
    • 小结
    • 练习


门控循环单元(GRU)

之前我们讨论了如何在循环神经网络中计算梯度,以及矩阵连续乘积可以导致梯度消失或梯度爆炸的问题。
下面我们简单思考一下这种梯度异常在实践中的意义:

  • 我们可能会遇到这样的情况:早期观测值对预测所有未来观测值具有非常重要的意义。
    考虑一个极端情况,其中第一个观测值包含一个校验和,目标是在序列的末尾辨别校验和是否正确。
    在这种情况下,第一个词元的影响至关重要。
    我们希望有某些机制能够在一个记忆元里存储重要的早期信息。
    如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度,因为它会影响所有后续的观测值。
  • 我们可能会遇到这样的情况:一些词元没有相关的观测值。
    例如,在对网页内容进行情感分析时,可能有一些辅助HTML代码与网页传达的情绪无关。
    我们希望有一些机制来跳过隐状态表示中的此类词元。
  • 我们可能会遇到这样的情况:序列的各个部分之间存在逻辑中断。
    例如,书的章节之间可能会有过渡存在,或者证券的熊市和牛市之间可能会有过渡存在。
    在这种情况下,最好有一种方法来重置我们的内部状态表示。

在学术界已经提出了许多方法来解决这类问题。
其中最早的方法是"长短期记忆"(long-short-term memory,LSTM),我们将在之后讨论。
门控循环单元(gated recurrent unit,GRU)是一个稍微简化的变体,通常能够提供同等的效果,并且计算的速度明显更快。
由于门控循环单元更简单,我们从它开始解读。

门控隐状态

门控循环单元与普通的循环神经网络之间的关键区别在于:
前者支持隐状态的门控。
这意味着模型有专门的机制来确定应该何时更新隐状态,以及应该何时重置隐状态。
这些机制是可学习的,并且能够解决了上面列出的问题。
例如,如果第一个词元非常重要,模型将学会在第一次观测之后不更新隐状态。
同样,模型也可以学会跳过不相关的临时观测。
最后,模型还将学会在需要的时候重置隐状态。
下面我们将详细讨论各类门控。

重置门和更新门

我们首先介绍重置门(reset gate)和更新门(update gate)。
我们把它们设计成 ( 0 , 1 ) (0, 1) (0,1)区间中的向量,这样我们就可以进行凸组合。(“凸组合”是一个数学概念,主要用于优化和几何领域。它指的是在给定的点(或向量)集合中,通过线性组合得到的点,其中每个点的系数都是非负的,并且所有系数的和为1。)
重置门允许我们控制“可能还想记住”的过去状态的数量;
更新门将允许我们控制新状态中有多少个是旧状态的副本。

我们从构造这些门控开始。
下图描述了门控循环单元中的重置门和更新门的输入,输入是由当前时间步的输入和前一时间步的隐状态给出。
两个门的输出是由使用sigmoid激活函数的两个全连接层给出。
在这里插入图片描述

我们来看一下门控循环单元的数学表达。对于给定的时间步 t t t,假设输入是一个小批量 X t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d} XtRn×d(样本个数 n n n,输入个数 d d d),上一个时间步的隐状态是 H t − 1 ∈ R n × h \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} Ht1Rn×h(隐藏单元个数 h h h)。
那么,重置门 R t ∈ R n × h \mathbf{R}_t \in \mathbb{R}^{n \times h} RtRn×h和更新门 Z t ∈ R n × h \mathbf{Z}_t \in \mathbb{R}^{n \times h} ZtRn×h的计算如下所示:

R t = σ ( X t W x r + H t − 1 W h r + b r ) , Z t = σ ( X t W x z + H t − 1 W h z + b z ) , \begin{aligned} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z), \end{aligned} Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),

其中 W x r , W x z ∈ R d × h \mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h} Wxr,WxzRd×h W h r , W h z ∈ R h × h \mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h} Whr,WhzRh×h是权重参数, b r , b z ∈ R 1 × h \mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h} br,bzR1×h是偏置参数。
请注意,在求和过程中会触发广播机制,我们使用sigmoid函数(将输入值转换到区间 ( 0 , 1 ) (0, 1) (0,1)

候选隐状态

接下来,让我们将重置门 R t \mathbf{R}_t Rt
与 常规隐状态更新机制集成,得到在时间步 t t t候选隐状态(candidate hidden state) H ~ t ∈ R n × h \tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h} H~tRn×h

H ~ t = tanh ⁡ ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) , \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h), H~t=tanh(XtWxh+(RtHt1)Whh+bh),
:eqlabel:gru_tilde_H

其中 W x h ∈ R d × h \mathbf{W}_{xh} \in \mathbb{R}^{d \times h} WxhRd×h W h h ∈ R h × h \mathbf{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h是权重参数, b h ∈ R 1 × h \mathbf{b}_h \in \mathbb{R}^{1 \times h} bhR1×h是偏置项,符号 ⊙ \odot 是Hadamard积(按元素乘积)运算符。
在这里,我们使用tanh非线性激活函数来确保候选隐状态中的值保持在区间 ( − 1 , 1 ) (-1, 1) (1,1)中。

与 一般的RNN相比, :eqref:gru_tilde_H中的 R t \mathbf{R}_t Rt H t − 1 \mathbf{H}_{t-1} Ht1的元素相乘可以减少以往状态的影响。
每当重置门 R t \mathbf{R}_t Rt中的项接近 1 1 1时,我们恢复一个如 :eqref:rnn_h_with_state中的普通的循环神经网络。
对于重置门 R t \mathbf{R}_t Rt中所有接近 0 0 0的项,候选隐状态是以 X t \mathbf{X}_t Xt作为输入的多层感知机的结果。
因此,任何预先存在的隐状态都会被重置为默认值。

:numref:fig_gru_2说明了应用重置门之后的计算流程。

在这里插入图片描述label:fig_gru_2

隐状态

上述的计算结果只是候选隐状态,我们仍然需要结合更新门 Z t \mathbf{Z}_t Zt的效果。
这一步确定新的隐状态 H t ∈ R n × h \mathbf{H}_t \in \mathbb{R}^{n \times h} HtRn×h在多大程度上来自旧的状态 H t − 1 \mathbf{H}_{t-1} Ht1和新的候选状态 H ~ t \tilde{\mathbf{H}}_t H~t。更新门 Z t \mathbf{Z}_t Zt仅需要在 H t − 1 \mathbf{H}_{t-1} Ht1 H ~ t \tilde{\mathbf{H}}_t H~t之间进行按元素的凸组合就可以实现这个目标。这就得出了门控循环单元的最终更新公式:

H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t . \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t. Ht=ZtHt1+(1Zt)H~t.

每当更新门 Z t \mathbf{Z}_t Zt接近 1 1 1时,模型就倾向只保留旧状态。
此时,来自 X t \mathbf{X}_t Xt的信息基本上被忽略,从而有效地跳过了依赖链条中的时间步 t t t
相反,当 Z t \mathbf{Z}_t Zt接近 0 0 0时,新的隐状态 H t \mathbf{H}_t Ht就会接近候选隐状态 H ~ t \tilde{\mathbf{H}}_t H~t
这些设计可以帮助我们处理循环神经网络中的梯度消失问题,并更好地捕获时间步距离很长的序列的依赖关系。
例如,如果整个子序列的所有时间步的更新门都接近于 1 1 1,则无论序列的长度如何,在序列起始时间步的旧隐状态都将很容易保留并传递到序列结束。

下图说明了更新门起作用后的计算流。

在这里插入图片描述label:fig_gru_3

总之,门控循环单元具有以下两个显著特征:

  • 重置门有助于捕获序列中的短期依赖关系;
  • 更新门有助于捕获序列中的长期依赖关系。

从零开始实现

为了更好地理解门控循环单元模型,我们从零开始实现它。
首先,我们读取时间机器数据集:

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

下一步是初始化模型参数。
我们从标准差为 0.01 0.01 0.01的高斯分布中提取权重,
并将偏置项设为 0 0 0,超参数num_hiddens定义隐藏单元的数量,
实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

定义模型

现在我们将[定义隐状态的初始化函数]init_gru_state
与从零实现RNN中定义的init_rnn_state函数一样,此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

现在我们准备[定义门控循环单元模型],
模型的架构与基本的循环神经网络单元是相同的,只是权重更新公式更为复杂。
下面代码中的 @ @ @与torch.mm作用相同。

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h) # 注意这之前的H都是上一时刻,这时的H才是这一时刻的HH = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

训练与预测

训练和预测的工作方式与从零实现RNN完全相同。
训练结束后,我们分别打印输出训练集的困惑度,以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.1, 38401.1 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
travelleryou can show black is white by argument said filby<Figure size 350x250 with 1 Axes>

在这里插入图片描述

简洁实现

高级API包含了前文介绍的所有配置细节,所以我们可以直接实例化门控循环单元模型。
这段代码的运行速度要快得多,因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 389565.2 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
traveller with a slight accession ofcheerfulness really thi<Figure size 350x250 with 1 Axes>

在这里插入图片描述

小结

  • 门控循环神经网络可以更好地捕获时间步距离很长的序列上的依赖关系。
  • 重置门有助于捕获序列中的短期依赖关系。
  • 更新门有助于捕获序列中的长期依赖关系。
  • 重置门打开时,门控循环单元包含基本循环神经网络;更新门打开时,门控循环单元可以跳过子序列。

练习

  1. 假设我们只想使用时间步 t ′ t' t的输入来预测时间步 t > t ′ t > t' t>t的输出。对于每个时间步,重置门和更新门的最佳值是什么?
  2. 调整和分析超参数对运行时间、困惑度和输出顺序的影响。
  3. 比较rnn.RNNrnn.GRU的不同实现对运行时间、困惑度和输出字符串的影响。
  4. 如果仅仅实现门控循环单元的一部分,例如,只有一个重置门或一个更新门会怎样?

http://www.ppmy.cn/news/1534081.html

相关文章

HTML流光爱心

文章目录 序号目录1HTML满屏跳动的爱心&#xff08;可写字&#xff09;2HTML五彩缤纷的爱心3HTML满屏漂浮爱心4HTML情人节快乐5HTML蓝色爱心射线6HTML跳动的爱心&#xff08;简易版&#xff09;7HTML粒子爱心8HTML蓝色动态爱心9HTML跳动的爱心&#xff08;双心版&#xff09;1…

Maya没有Arnold材质球

MAYA 没有Arnold材质球_哔哩哔哩_bilibili

学习docker第二弹------基本命令[帮助启动类命令、镜像命令、容器命令]

docker目录 前言基本命令帮助启动类命令停止docker服务查看docker状态启动docker重启docker开机启动docker查看概要信息查看总体帮助文档查看命令帮助文档 镜像命令查看所有的镜像 -a查看镜像ID -q在仓库里面查找redis拉取镜像查看容器/镜像/数据卷所占内存删除一个镜像删除多个…

第四周做题总结_数据结构_栈与应用

id:144 A. 前驱后继–双向链表&#xff08;线性结构&#xff09; 题目描述 在双向链表中&#xff0c;A有一个指针指向了后继节点B&#xff0c;同时&#xff0c;B又有一个指向前驱节点A的指针。这样不仅能从链表头节点的位置遍历整个链表所有节点&#xff0c;也能从链表尾节点…

Tensorflow2.0

Tensorflow2.0 有深度学习基础的建议直接看class3 class1 介绍 人工智能3学派 行为主义:基于控制论&#xff0c;构建感知-动作控制系统。(控制论&#xff0c;如平衡、行走、避障等自适应控制系统) 符号主义:基于算数逻辑表达式&#xff0c;求解问题时先把问题描述为表达式…

【需求分析】软件系统需求设计报告,需求分析报告,需求总结报告(原件PPT)

第1章 序言 第2章 引言 2.1 项目概述 2.1.1 项目背景 2.1.2 项目目标 2.2 编写目的 2.3 文档约定 2.4 预期读者及阅读建议 第3章 技术要求 3.1 软件开发要求 3.1.1 接口要求 3.1.2 系统专有技术 3.1.3 查询功能 3.1.4 数据安全 3.1.5 可靠性要求 3.1.6 稳定性要求 3.1.7 安全性…

2024年三个月网络安全(黑客技术)入门教程

&#x1f91f; 基于入门网络安全/黑客打造的&#xff1a;&#x1f449;黑客&网络安全入门&进阶学习资源包 前言 什么是网络安全 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、…

基于SpringBoot+Vue+MySQL的考勤管理系统

系统展示 管理员界面 用户界面 系统背景 随着企业规模的扩大和管理的精细化&#xff0c;传统的考勤方式已经无法满足现代企业的需求。纸质签到、人工统计不仅效率低下&#xff0c;还容易出错。因此&#xff0c;开发一套基于SpringBootVueMySQL的考勤管理系统显得尤为重要。该系…