GRU模块:nn.GRU层的输出state与output

server/2024/12/22 0:40:50/

       在 GRU(Gated Recurrent Unit)中,outputstate 都是由 GRU 层的循环计算产生的,它们之间有直接的关系。state 实际上是 output 中最后一个时间步的隐藏状态。

GRU 的基本公式

GRU 的核心计算包括更新门(update gate)和重置门(reset gate),以及候选隐藏状态(candidate hidden state)。数学表达式如下:

  1. 更新门 \( z_t \): \[ z_t = \sigma(W_z \cdot h_{t-1} + U_z \cdot x_t) \]
       其中,\( \sigma \) 是sigmoid 函数,\( W_z \) 和 \( U_z \) 分别是对应于隐藏状态和输入的权重矩阵,\( h_{t-1} \) 是上一个时间步的隐藏状态,\( x_t \) 是当前时间步的输入。

  2. 重置门 \( r_t \):
       \[ r_t = \sigma(W_r \cdot h_{t-1} + U_r \cdot x_t) \]
       \( W_r \) 和 \( U_r \) 是更新门中定义的相似权重矩阵。

  3. 候选隐藏状态 \( \tilde{h}_t \):
       \[ \tilde{h}_t = \tanh(W \cdot r_t \odot h_{t-1} + U \cdot x_t) \]
       这里,\( \tanh \) 是激活函数,\( \odot \) 表示元素乘法(Hadamard product),\( W \) 和 \( U \) 是隐藏状态的权重矩阵。

  4. 最终隐藏状态 \( h_t \):
       \[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

output 和 state 的关系

  • output:在 GRU 中,output 包含了序列中每个时间步的隐藏状态。具体来说,对于每个时间步 \( t \),output 的第 \( t \) 个元素就是该时间步的隐藏状态 \( h_t \)。

  • state:state 是 GRU 层最后一层的隐藏状态,也就是 output 中最后一个时间步的隐藏状态 \( h_{T-1} \),其中 \( T \) 是序列的长度。

数学表达式

如果我们用 \( O \) 表示 output,\( S \) 表示 state,\( T \) 表示时间步的总数,那么:

\[ O = [h_0, h_1, ..., h_{T-1}] \]
\[ S = h_{T-1} \]

因此,state 实际上是 output 中最后一个元素,即 \( S = O[T-1] \)。

在 PyTorch 中,output 和 state 都是由 GRU 层的 `forward` 方法计算得到的。`output` 是一个三维张量,包含了序列中每个时间步的隐藏状态,而 `state` 是一个二维张量,仅包含最后一个时间步的隐藏状态。

代码示例

class Seq2SeqEncoder(d2l.Encoder):
"""⽤于序列到序列学习的循环神经⽹络编码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):super(Seq2SeqEncoder, self).__init__(**kwargs)# 嵌⼊层self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,dropout=dropout)def forward(self, X, *args):# 输出'X'的形状:(batch_size,num_steps,embed_size)X = self.embedding(X)# 在循环神经⽹络模型中,第⼀个轴对应于时间步X = X.permute(1, 0, 2)# 如果未提及状态,则默认为0output, state = self.rnn(X)# output的形状:(num_steps,batch_size,num_hiddens)# state的形状:(num_layers,batch_size,num_hiddens)return output, state

output:在完成所有时间步后,最后⼀层的隐状态的输出output是⼀个张量(output由编码器的循环层返回),其形状为(时间步数,批量⼤⼩,隐藏单元数)。

state:最后⼀个时间步的多层隐状态是state的形状是(隐藏层的数量,批量⼤⼩, 隐藏单元的数量)。


http://www.ppmy.cn/server/35604.html

相关文章

PHP介绍

🐌博主主页:🐌​倔强的大蜗牛🐌​ 📚专栏分类:PHP❤️感谢大家点赞👍收藏⭐评论✍️ 目录 一、PHP是什么? 二、 PHP 文件是什么? 三、PHP能做什么? 四、P…

pubg绝地求生吃鸡加速器推荐 pubg吃鸡加速器免费低延迟

《绝地求生》(PUBG) 是由韩国Krafton工作室开发的一款战术竞技型射击类沙盒游戏。2022年1月12日,该游戏于主机和PC上可免费下载游玩。绝地求生已经上线了好久的时间,仍然保持的很好的热度,无时无刻都在涌入新手玩家。游戏有多张地图可供玩家选…

Vulnhub-DIGITALWORLD.LOCAL: VENGEANCE渗透

文章目录 前言1、靶机ip配置2、渗透目标3、渗透概括 开始实战一、信息获取二、smb下载线索三、制作字典四、爆破压缩包密码五、线索分析六、提权!!! Vulnhub靶机:DIGITALWORLD.LOCAL: VENGEANCE ( digitalworld.local: VENGEANCE …

LeetCode:两数之和

文章收录于LeetCode专栏 LeetCode地址 两数之和 给定一个整数数组nums和一个整数目标值target,请你在该数组中找出和为目标值的那两个整数,并返回它们的数组下标。   你可以假设每种输入只会对应一个答案。但是数组中同一个元素在答案里不能重复出现。…

DevicData-P-XXXXXX勒索病毒有什么特点,如何恢复重要数据?

DevicData-P-XXXXXXXX这种新型勒索病毒有什么特点? DevicData-P-XXXXXXXX勒索病毒的特点主要包括以下几点: 文件加密:该病毒会搜索并加密用户计算机上的重要文件,如文档、图片、视频等,使其无法正常打开。这是通过强…

021、Python+fastapi,第一个Python项目走向第21步:ubuntu 24.04 docker 安装mysql8集群、redis集群(二)

系列文章目录 pythonvue3fastapiai 学习_浪淘沙jkp的博客-CSDN博客https://blog.csdn.net/jiangkp/category_12623996.html 前言 安装redis 我会以三种方式安装,在5月4号修改完成 第一、直接最简单安装,适用于测试环境玩玩 第二、conf配置安装 第三…

HTML 官网进行移动端和 PC 端适配

使用响应式布局:确保你的 HTML 结构使用了响应式布局,即页面的元素能够根据不同设备的屏幕大小和分辨率进行自适应调整。 媒体查询:在 CSS 中使用媒体查询来针对不同的设备条件应用特定的样式。例如,你可以针对手机、平板和桌面设…

Go语言的包管理工具go mod与之前的GOPATH有什么区别?

在深入探讨Go语言的包管理工具go mod与之前的GOPATH之间的区别之前,我们首先需要理解这两个概念各自的作用和背景。 GOPATH时代 在Go语言早期版本中,GOPATH是一个非常重要的环境变量。它告诉Go工具链在哪里查找你的Go代码、第三方库以及编译后的二进制…