什么是门控循环单元?

embedded/2025/2/3 1:38:08/

一、概念

        门控循环单元(Gated Recurrent Unit,GRU)是一种改进的循环神经网络(RNN),由Cho等人在2014年提出。GRU是LSTM的简化版本,通过减少门的数量和简化结构,保留了LSTM的长时间依赖捕捉能力,同时提高了计算效率。GRU通过引入两个门(重置门和更新门)来控制信息的流动。与LSTM不同,GRU没有单独的细胞状态,而是将隐藏状态直接作为信息传递的载体,因此结构更简单,计算效率更高。

二、核心算法

        令x_{t}为时间步 t 的输入向量,h_{t-1}为前一个时间步的隐藏状态向量,h_{t}为当前时间步的隐藏状态向量,r_{t}为当前时间步的重置门向量,z_{t}为当前时间步的更新门向量,\bar{h_{t}}为当前时间步的候选隐藏状态向量,W_{r},W_{z},W_{h}分别为各门的权重矩阵,b_{r},b_{z},b_{h}为偏置向量,\sigma为sigmoid激活函数,tanh为tanh激活函数,*为元素级乘法。

1、重置门

        重置门控制前一个时间步的隐藏状态对当前时间步的影响。通过sigmoid激活函数,重置门的输出在0到1之间,表示前一个隐藏状态元素被保留的比例。

r_{t} = \sigma(W_{r} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{r})

2、更新门

        更新门控制前一个时间步的隐藏状态和当前时间步的候选隐藏状态的混合比例。通过sigmoid激活函数,更新门的输出在0到1之间,表示前一个隐藏状态元素被保留的比例。

z_{t} = \sigma(W_{z} \cdot \left [ h_{t-1}, x_{t} \right ] + b_{z})

3、候选隐藏状态

        候选隐藏状态结合当前输入和前一个时间步的隐藏状态生成。重置门的输出与前一个隐藏状态相乘,表示保留的旧信息。然后与当前输入一起通过tanh激活函数生成候选隐藏状态。

\bar{h_{t}} = tanh(W_{h} \cdot \left [ r_{t} \ast h_{t-1}, x_{t} \right ] + b_{h})

4、隐藏状态更新

        隐藏状态结合更新门的结果进行更新。更新门的输出与前一个隐藏状态相乘,表示保留的旧信息。更新门的补数与候选隐藏状态相乘,表示写入的新信息。两者相加得到当前时间步的隐藏状态。

h_{t} = (1-z_{t}) \ast h_{t-1} + z_{t} \ast \bar{h_{t}}

三、python实现

python">import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt# 设置随机种子
torch.manual_seed(0)
np.random.seed(0)# 生成正弦波数据
timesteps = 1000
sin_wave = np.array([np.sin(2 * np.pi * i / timesteps) for i in range(timesteps)])# 创建数据集
def create_dataset(data, time_step=1):dataX, dataY = [], []for i in range(len(data) - time_step - 1):a = data[i:(i + time_step)]dataX.append(a)dataY.append(data[i + time_step])return np.array(dataX), np.array(dataY)time_step = 10
X, y = create_dataset(sin_wave, time_step)# 数据预处理
X = X.reshape(X.shape[0], time_step, 1)
y = y.reshape(-1, 1)# 转换为Tensor
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)# 划分训练集和测试集
train_size = int(len(X) * 0.7)
test_size = len(X) - train_size
trainX, testX = X[:train_size], X[train_size:]
trainY, testY = y[:train_size], y[train_size:]# 定义RNN模型
class GRUModel(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUModel, self).__init__()self.hidden_size = hidden_sizeself.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(1, x.size(0), self.hidden_size)out, _ = self.gru(x, h0)out = self.fc(out[:, -1, :])return outinput_size = 1
hidden_size = 50
output_size = 1
model = GRUModel(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 训练模型
num_epochs = 50
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(trainX)loss = criterion(outputs, trainY)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 预测
model.eval()
train_predict = model(trainX)
test_predict = model(testX)
train_predict = train_predict.detach().numpy()
test_predict = test_predict.detach().numpy()# 绘制结果
plt.figure(figsize=(10, 6))
plt.plot(sin_wave, label='Original Data')
plt.plot(np.arange(time_step, time_step + len(train_predict)), train_predict, label='Training Predict')
plt.plot(np.arange(time_step + len(train_predict), time_step + len(train_predict) + len(test_predict)), test_predict, label='Test Predict')
plt.legend()
plt.show()

四、总结

        GRU的结构比LSTM更简单,只有两个门(重置门和更新门),没有单独的细胞状态。这使得GRU的计算复杂度较低,训练和推理速度更快。通过引入重置门和更新门,GRU也有效地解决了标准RNN在处理长序列时的梯度消失和梯度爆炸问题。然而,在需要更精细的门控制和信息流动的任务中,LSTM的性能可能优于GRU。因此在我们实际的建模过程中,可以根据数据特点选择合适的RNN系列模型,并没有哪个模型能在所有任务中都具有优势。


http://www.ppmy.cn/embedded/159057.html

相关文章

本地部署DeepSeek教程(Mac版本)

第一步、下载 Ollama 官网地址:Ollama 点击 Download 下载 我这里是 macOS 环境 以 macOS 环境为主 下载完成后是一个压缩包,双击解压之后移到应用程序: 打开后会提示你到命令行中运行一下命令,附上截图: 若遇…

一文学会HTML编程之视频+图文详解详析

前言 本文涵盖了html的所有核心知识点,因为篇幅非常长,故题主将本教程分为七个层次,师傅们结合自身的时间安排,灵活调整即可。 视频教程 哔哩哔哩(B站)搜索框中输入“uid3546393096489381”即可 用户&a…

RK3588平台开发系列讲解(ARM篇)ARM64底层中断处理

文章目录 一、异常级别二、异常分类2.1、同步异常2.2、异步异常三、中断向量表沉淀、分享、成长,让自己和他人都能有所收获!😄 一、异常级别 ARM64处理器确实定义了4个异常级别(Exception Levels, EL),分别是EL0到EL3。这些级别用于管理处理器的特权级别和权限,级别越高…

AD中如何画插件的封装

AD中如何画插件的封装 一、说明 元器件的封装在大类上只分为贴片器件和插件器件,目前随着贴片机的大量应用,插件器件的使用已经减少了很多;但是对于很多小批量小规模的电路板生产来说,插件器件对于电路板的生产入门要求更低,用一般的焊接技工就可以完成,并且插件器件也更…

设计模式Python版 组合模式

文章目录 前言一、组合模式二、组合模式实现方式三、组合模式示例四、组合模式在Django中的应用 前言 GOF设计模式分三大类: 创建型模式:关注对象的创建过程,包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式…

【C++】特殊类设计

目录 一、请设计一个类,不能被拷贝二、请设计一个类,只能在堆上创建对象三、请设计一个类,只能在栈上创建对象四、请设计一个类,不能被继承五、请设计一个类,只能创建一个对象(单例模式)5.1 饿汉模式5.2 懒汉模式 结尾…

四.3 Redis 五大数据类型/结构的详细说明/详细使用( hash 哈希表数据类型详解和使用)

四.3 Redis 五大数据类型/结构的详细说明/详细使用( hash 哈希表数据类型详解和使用) 文章目录 四.3 Redis 五大数据类型/结构的详细说明/详细使用( hash 哈希表数据类型详解和使用)2.hash 哈希表常用指令(详细讲解说明)2.1 hset …

javascript-es6 (一)

作用域(scope) 规定了变量能够被访问的“范围”,离开了这个“范围”变量便不能被访问 局部作用域 函数作用域: 在函数内部声明的变量只能在函数内部被访问,外部无法直接访问 function getSum(){ //函数内部是函数作用…