第十五站:循环神经网络(RNN)与长短期记忆网络(LSTM)

server/2025/3/4 0:33:19/

1. 循环神经网络(RNN)概述

RNN 是一种非常适合处理序列数据的神经网络。与传统的前馈神经网络不同,RNN 具有一个 循环连接,它可以 记住 前一个时刻的信息,并将其传递到当前时刻。

RNN 的工作原理

  • 输入序列:RNN 接收一个序列的输入,比如时间序列数据、文本数据等。
  • 隐藏状态:RNN 的核心是其 隐藏状态,它存储了对输入序列历史的记忆。
  • 递归计算:在每一步,RNN 会计算当前时刻的隐藏状态,并将其传递到下一时刻的计算中。

RNN 的数学表示如下:
h t = σ ( W h h h t − 1 + W x h x t + b ) h_t = \sigma(W_{hh}h_{t-1} + W_{xh}x_t + b) ht=σ(Whhht1+Wxhxt+b)

  • h t h_t ht 是当前时刻的隐藏状态。
  • W h h W_{hh} Whh是隐藏状态到隐藏状态的权重。
  • W x h W_{xh} Wxh是输入到隐藏状态的权重。
  • x t x_t xt是当前时刻的输入。
  • b b b是偏置项。

2. RNN 的局限性

尽管 RNN 能够处理序列数据,但它存在 梯度消失和梯度爆炸问题。特别是在长序列上,RNN 很难保持长时间的依赖关系。

  • 梯度消失:在训练过程中,当梯度经过多次反向传播时,可能会变得非常小,导致网络无法有效学习长期依赖关系。
  • 梯度爆炸:相反,梯度也可能变得非常大,导致训练不稳定。

为了解决这些问题,我们引入了 长短期记忆网络(LSTM)

3. 长短期记忆网络(LSTM)

LSTM 是一种特殊的 RNN,它引入了 记忆单元门控机制,使得网络能够更好地学习和保持长期依赖。

LSTM 的工作原理

LSTM 通过 遗忘门、输入门和输出门 来控制信息的流动。

  • 遗忘门(Forget Gate):决定哪些信息需要丢弃。
  • 输入门(Input Gate):决定哪些信息需要存储到记忆单元中。
  • 输出门(Output Gate):决定从记忆单元中输出哪些信息。

LSTM 中的每个步骤计算如下:

  1. 遗忘门:决定当前隐藏状态中有多少信息需要被丢弃。
    f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)
  2. 输入门:决定当前输入中有多少信息需要被保存。
    i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)
  3. 记忆单元更新:更新当前的记忆单元。
    C t = f t ⋅ C t − 1 + i t ⋅ C ~ t C_t = f_t \cdot C_{t-1} + i_t \cdot \tilde{C}_t Ct=ftCt1+itC~t
  4. 输出门:决定从记忆单元中输出哪些信息。
    o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)
  5. 隐藏状态更新:最终的隐藏状态。
    h t = o t ⋅ tanh ⁡ ( C t ) h_t = o_t \cdot \tanh(C_t) ht=ottanh(Ct)

4. LSTM 的优势

  • 长期依赖:LSTM 能够更好地捕捉长期依赖关系,解决了传统 RNN 的梯度消失问题。
  • 门控机制:通过遗忘门、输入门和输出门,LSTM 控制了信息的流动,避免了无用信息的积累。

5. LSTM 示例代码:

下面是一个使用 LSTM 进行时间序列预测的简单示例代码:

import torch
import torch.nn as nn
import torch.optim as optim# 定义 LSTM 网络结构
class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(LSTMModel, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size)  # LSTM 层self.fc = nn.Linear(hidden_size, output_size)  # 全连接层def forward(self, x):# 初始化 LSTM 的隐藏状态和细胞状态h0 = torch.zeros(1, x.size(1), hidden_size).to(x.device)  # 隐藏状态 h0c0 = torch.zeros(1, x.size(1), hidden_size).to(x.device)  # 细胞状态 c0# LSTM 前向传播lstm_out, (hn, cn) = self.lstm(x, (h0, c0))# 使用最后一个时间步的输出进行预测out = self.fc(lstm_out[-1])  # lstm_out[-1] 形状为 (batch_size, hidden_size)return out# 输入参数
input_size = 1  # 输入特征维度
hidden_size = 64  # LSTM 隐藏层维度
output_size = 1  # 输出维度# 创建 LSTM 模型实例
model = LSTMModel(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器# 假设我们有一个简单的时间序列数据
data = torch.randn(10, 100, 1)  # 形状为 (sequence_length, batch_size, input_size)
labels = torch.randn(100, 1)  # 目标值,形状为 (batch_size, output_size)# 训练循环
for epoch in range(100):model.train()  # 设置模型为训练模式optimizer.zero_grad()  # 清空梯度# 预测output = model(data)  # 前向传播# 计算损失loss = criterion(output, labels)  # 计算损失# 反向传播loss.backward()  # 计算梯度# 更新参数optimizer.step()  # 更新模型参数# 输出损失值if epoch % 10 == 0:print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

关键点说明:

  1. LSTM 层:

    • self.lstm = nn.LSTM(input_size, hidden_size):定义了一个 LSTM 层,input_size 表示每个时间步的输入维度,hidden_size 是 LSTM 层的隐藏单元数量。
    • LSTM 网络有一个非常重要的特点,即它能够通过递归传递信息(记忆)来处理时间序列数据。
  2. 全连接层:

    • self.fc = nn.Linear(hidden_size, output_size):全连接层用于将 LSTM 的输出映射到最终的预测结果。
    • 在这里,我们将 hidden_size 的输出映射到 output_size,适用于回归任务。
  3. 前向传播:

    • lstm_out, (hn, cn) = self.lstm(x, (h0, c0)):将输入数据 x 传入 LSTM 层,并得到 LSTM 的输出和最后的隐藏状态 hn、细胞状态 cn
    • out = self.fc(lstm_out[-1]):选择 LSTM 输出序列的最后一个时间步的输出,传递给全连接层进行预测。
  4. 训练过程:

    • 清空梯度:每个 epoch 之前使用 optimizer.zero_grad() 清空之前计算的梯度。
    • 损失计算和反向传播:通过 criterion(output, labels) 计算损失,并通过 loss.backward() 进行反向传播来计算梯度。
    • 优化器更新optimizer.step() 用来更新模型的参数。

6. LSTM 在实际的生活中也有很多应用地方:

LSTM 广泛应用于 时间序列分析自然语言处理(NLP) 中:

  1. 时间序列预测:LSTM 可以用来预测股票价格、天气变化等序列数据。
  2. 文本生成和语言建模:LSTM 可用于生成文本或建模语言的上下文。
  3. 机器翻译:LSTM 用于翻译不同语言之间的句子。
  4. 语音识别:LSTM 可用于处理语音信号,并将其转换为文本。

结语:因为博主的一些原因,机器学习系列就更到这里,学到这里各位也应该对机器学习的基础有一定的了解,并能搭建属于自己的一个神经网络,并去进行调优,改进,并部署到实际的现实需求当中


http://www.ppmy.cn/server/172206.html

相关文章

设计模式Python版 观察者模式

文章目录 前言一、观察者模式二、观察者模式示例 前言 GOF设计模式分三大类: 创建型模式:关注对象的创建过程,包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式。结构型模式:关注类和对象之间的组…

TCP 三次握手与四次挥手

TCP 三次握手与四次挥手知识总结 一、TCP 连接与断开的核心机制 1. 三次握手(建立连接) 目的: 建立客户端与服务端之间的双向传输通道,确保双方都能确认对方的接收和发送能力,为后续的数据传输奠定可靠基础。 流程…

【AIGC系列】3:Stable Diffusion模型原理介绍

AIGC系列博文: 【AIGC系列】1:自编码器(AutoEncoder, AE) 【AIGC系列】2:DALLE 2模型介绍(内含扩散模型介绍) 【AIGC系列】3:Stable Diffusion模型原理介绍 【AIGC系列】4&#xff1…

【网络安全 | 渗透测试】GraphQL精讲二:发现API漏洞

未经许可,不得转载。 推荐阅读:【网络安全 | 渗透测试】GraphQL精讲一:基础知识 文章目录 GraphQL API 漏洞寻找 GraphQL 端点通用查询常见的端点名称请求方法初步测试利用未清理的参数发现模式信息使用 introspection探测 introspection运行完整的 introspection 查询可视化…

阿里云服务器宝塔终端如何创建fastadmin插件

1. 进入宝塔终端 2. cd / 进入根目录 3. FastAdmin 可以通过命令行创建一个插件,首先我们将工作目录切换到我们的项目根目录,也就是think文件所在的目录。 cd /var/www/yoursite/ 4.然后我们在命令行输入 php think addon -a mydemo -c create …

目标检测——数据处理

1. Mosaic 数据增强 Mosaic 数据增强步骤: (1). 选择四个图像: 从数据集中随机选择四张图像。这四张图像是用来组合成一个新图像的基础。 (2) 确定拼接位置: 设计一个新的画布(输入size的2倍),在指定范围内找出一个随机点(如…

火绒终端安全管理系统V2.0网络防御功能介绍

网络防御是指通过一系列技术、策略和措施,保护网络系统、数据和资源免受未经授权的访问、攻击、破坏或泄露。 火绒终端安全管理系统:网络防御功能包含网络入侵拦截、横向渗透防护、对外攻击检测、僵尸网络防护、Web服务保护、暴破攻击防护、远程登录防护…

广东GZ033-任务E:数据可视化(15 分)-用柱状图展示销售金额最高的6 个月

广东GZ033-任务E:数据可视化(15 分) 用柱状图展示销售金额最高的6 个月 编写Vue 工程代码, 读取虚拟机bigdata-spark 的/opt/data 目录下的 supermarket_visualization.csv,用柱状图展示2024 年销售金额最高的6 个月&a…