深度学习学习经验——长短期记忆网络(LSTM)

news/2024/9/18 14:59:45/ 标签: 深度学习, 学习, lstm

长短期记忆网络(LSTM)

长短期记忆网络(LSTM,Long Short-Term Memory)是一种特殊的循环神经网络(RNN),专为解决 RNN 中长期依赖问题而设计。LSTM 引入了三个门和一个细胞状态(cell state),以便更好地控制信息的流动,确保网络能够记住长期的依赖关系。我们将通过一个逐步深入的案例来讲解 LSTM 的内部结构和工作机制,并使用 PyTorch 实现一个 LSTM 模型。

1. 问题描述

与我上一篇讲解RNN的文章相似,假设我们仍然要预测未来的天气情况,但是这次数据包含更多的噪声,且我们希望模型能够更好地“记住”一段时间内的趋势信息。这时,LSTM 比普通 RNN 更适合这种任务,因为它能够通过门控机制更精确地控制信息流动。

2. LSTM 的基本原理

LSTM 通过引入 输入门(Input Gate)、遗忘门(Forget Gate)、输出门(Output Gate)和 细胞状态(Cell State)来管理信息的记忆和遗忘。

2.1 细胞状态(Cell State)

细胞状态是贯穿整个 LSTM 的主线,类似于一个“传送带”,它能够允许信息在序列中几乎不受干扰地传递下去。LSTM 通过少量的线性相互作用,轻松地让信息在其上流动,只有少数的部分会被门结构所改变。

2.2 遗忘门(Forget Gate)

遗忘门决定了细胞状态中哪些信息需要丢弃。这个门读取当前的输入 ( x t (x_t (xt) 和上一时刻的隐藏状态 ( h t − 1 (h_{t-1} (ht1),并输出一个介于 0 和 1 之间的值,其中 0 代表完全忘记,1 代表完全保留。

f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

forget_gate = nn.Sigmoid()
f_t = forget_gate(torch.matmul(W_f, torch.cat((h_{t-1}, x_t), dim=1)) + b_f)
2.3 输入门(Input Gate)

输入门控制哪些新的信息将被写入细胞状态。这个过程分为两步:首先,输入门生成一个控制写入的信号,然后通过一个 tanh 层创建一个新的候选细胞状态。

i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)

C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)

input_gate = nn.Sigmoid()
i_t = input_gate(torch.matmul(W_i, torch.cat((h_{t-1}, x_t), dim=1)) + b_i)candidate_layer = nn.Tanh()
C_tilda = candidate_layer(torch.matmul(W_C, torch.cat((h_{t-1}, x_t), dim=1)) + b_C)
2.4 更新细胞状态(Cell State)

在更新细胞状态时,我们结合了遗忘门的结果 ( f t (f_t (ft) 和输入门的结果 ( i t (i_t (it) 以及候选细胞状态 ( C ~ t (\tilde{C}_t (C~t):

C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t * C_{t-1} + i_t * \tilde{C}_t Ct=ftCt1+itC~t

C_t = f_t * C_{t-1} + i_t * C_tilda
2.5 输出门(Output Gate)

输出门决定了当前的隐藏状态 ( h t (h_t (ht) 是什么,同时输出门还控制了有多少细胞状态信息能够传递到下一层。首先,输出门生成一个信号,然后结合当前的细胞状态生成新的隐藏状态。

o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)

h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t * \tanh(C_t) ht=ottanh(Ct)

output_gate = nn.Sigmoid()
o_t = output_gate(torch.matmul(W_o, torch.cat((h_{t-1}, x_t), dim=1)) + b_o)h_t = o_t * torch.tanh(C_t)

3. 使用 PyTorch 实现 LSTM

现在我们使用 PyTorch 实现一个 LSTM 模型,并用它来预测天气。

3.1 导入必要的库
import torch
import torch.nn as nn
import numpy as np
3.2 定义 LSTM 模型
class WeatherLSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(WeatherLSTM, self).__init__()self.hidden_size = hidden_sizeself.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h_0 = torch.zeros(1, x.size(0), self.hidden_size)  # 初始化隐藏状态c_0 = torch.zeros(1, x.size(0), self.hidden_size)  # 初始化细胞状态out, (h_n, c_n) = self.lstm(x, (h_0, c_0))  # 计算所有时间步长的输出out = self.fc(out[:, -1, :])  # 取最后一个时间步长的输出并通过线性层return out

在这个模型中,LSTM 层代替了 RNN 层,它自动处理了前面介绍的遗忘门、输入门、输出门和细胞状态的更新。

3.3 准备数据
# 生成示例数据
data = np.array([[30, 31, 32, 33, 34],[32, 33, 34, 35, 36],[35, 36, 37, 38, 39]],dtype=np.float32)labels = np.array([35, 37, 40], dtype=np.float32)  # 对应的目标值# 转换为 PyTorch 张量
data = torch.from_numpy(data).unsqueeze(-1)  # 添加特征维度
labels = torch.from_numpy(labels).unsqueeze(-1)
3.4 训练模型
# 定义超参数
input_size = 1
hidden_size = 10
output_size = 1
num_epochs = 100
learning_rate = 0.01# 实例化模型
model = WeatherLSTM(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):model.train()outputs = model(data)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
3.5 使用模型进行预测
model.eval()
with torch.no_grad():test_input = torch.tensor([[36, 37, 38, 39, 40]], dtype=torch.float32).unsqueeze(-1)predicted_temperature = model(test_input)print(f"预测的温度: {predicted_temperature.item():.2f}")

4. 完整代码

import torch
import torch.nn as nn
import numpy as npclass WeatherLSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(WeatherLSTM, self).__init__()self.hidden_size = hidden_sizeself.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h_0 = torch.zeros(1, x.size(0), self.hidden_size)c_0 = torch.zeros(1, x.size(0), self.hidden_size)out, (h_n, c_n) = self.lstm(x, (h_0, c_0))out = self.fc(out[:, -1, :])return out# 生成示例数据
data = np.array([[30, 31, 32, 33, 34],[32, 33, 34, 35, 36],[35, 36, 37, 38, 39]],dtype=np.float32)labels = np.array([35, 37, 40], dtype=np.float32)data = torch.from_numpy(data).unsqueeze(-1)
labels = torch.from_numpy(labels).unsqueeze(-1)# 定义超参数
input_size = 1
hidden_size = 10
output_size = 1
num_epochs = 100
learning_rate = 0.01# 实例化模型
model = WeatherLSTM(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):model.train()outputs = model(data)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 进行预测
model.eval()
with torch.no_grad():test_input = torch.tensor([[36, 37, 38, 39, 40]], dtype=torch.float32).unsqueeze(-1)predicted_temperature = model(test_input)print(f"预测的温度: {predicted_temperature.item():.2f}")

5. 总结

LSTM 通过引入遗忘门、输入门、输出门和细胞状态,能够有效解决 RNN 中的长期依赖问题,使得模型可以更好地在序列数据中保留重要信息并进行预测。


http://www.ppmy.cn/news/1516731.html

相关文章

Spring DI 数据类型——构造注入

首先新建项目&#xff0c;可参考 初识 IDEA 、模拟三层--控制层、业务层和数据访问层 一、spring 环境搭建 &#xff08;一&#xff09;pom.xml 导相关坐标 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.…

FPGA时序约束

FPGA时序约束 目录 FPGA时序约束 前言1、建立和保持时间1.1 建立时间1.2 保持时间 2 时序路径2 具体约束2.1 IO约束2.1.1 管脚约束2.1.2 延迟约束2.1.3 虚拟时钟 2.2周期约束&#xff0c;FPGA内部的时序路径2.2.1 主时钟2.2.2 衍生时钟2.2.3 主时钟之间的相互关系2.2.4 使用BUF…

S3协议分片上传(minio)

文章目录 前言一、minio 为例二、使用步骤1.引入库2.读入数据总结前言 目前文件存储一般采用obs存储,也就是对象存储 比较流行的有: minio 阿里云 华为云 阿里云 腾讯云 七牛云 百度云 ,对于贫穷的我来说,当然选择免费开源的minio了,但是他们有一个统一的标准也就是S3协议,相…

一文读懂 DDD领域驱动设计

DDD&#xff08;Domain-Driven Design&#xff0c;领域驱动设计&#xff09;是一种软件开发方法&#xff0c;它强调软件系统设计应该以问题领域为中心&#xff0c;而不是技术实现为主导。DDD通过一系列手段如统一语言、业务抽象、领域划分和领域建模等来控制软件复杂度&#xf…

【Rust光年纪】Rust多媒体处理库全面比较:探索安全高效的多媒体处理利器

多媒体处理不再困扰&#xff1a;解锁Rust语言下的六大多媒体利器 前言 随着Rust语言的快速发展&#xff0c;越来越多的多媒体处理库和工具集开始出现&#xff0c;为开发人员提供了丰富的选择。本文将对几个用于Rust语言的多媒体处理库进行介绍&#xff0c;并对它们的核心功能…

【算法】希尔排序、计数排序、桶排序、基数排序

1 希尔排序 2 计数排序 3 桶排序 4 基数排序 1 希尔排序 """ 希尔排序&#xff08;Shell Sort&#xff09;是一种插入排序算法的改进版本&#xff0c;得名于其发明者Donald Shell。 它通过比较一定间隔的元素来进行排序&#xff0c;以减少数据移动的次数&#…

Jmeter进行http接口测试

&#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 1、jmeter-http接口测试脚本 jmeter进行http接口测试的主要步骤&#xff08;1.添加线程组 2.添加http请求 3.在http请求中写入接口的URL&#xff0c;路径&#xf…

【HarmonyOS NEXT星河版开发实战】天气查询APP

目录 前言 界面效果展示 首页 添加和删除 界面构建讲解 1. 获取所需数据 2. 在编译器中准备数据 3. index页面代码讲解 3.1 导入模块&#xff1a; 3.2 定义组件&#xff1a; 3.3 定义状态变量: 3.4 定义Tabs控制器: 3.5 定义按钮样式&#xff1a; 3.6 页面显示时触发…

Java核心API——Collection集合的工具类Collections

集合的排序 int类型的排序 * 集合的排序 * java.util.Collections是集合的工具类&#xff0c;提供了很多static方法用于操作集合 * 其中提供了一个名为sort的方法&#xff0c;可以对List集合进行自然排序(从小到大) List<Integer> list new ArrayList<>();Rand…

96页PPT集团战略解码会工具与操作流程

德勤集团在战略解码过程中通常会用到以下一些具体工具&#xff1a; 一、平衡计分卡&#xff08;Balanced Scorecard&#xff09; 财务维度&#xff1a; 明确关键财务指标&#xff0c;如营业收入、利润、投资回报率等。你可以通过分析历史财务数据和行业趋势&#xff0c;确定…

设计模式24-命令模式

设计模式24-命令模式 写在前面行为变化模式 命令模式的动机定义与结构定义结构 C 代码推导优缺点应用场景总结补充函数对象&#xff08;Functors&#xff09;定义具体例子示例&#xff1a;使用函数对象进行自定义排序代码说明输出结果具体应用 优缺点应用场景 命令模式&#xf…

鸿蒙位置服务

位置服务 1、首先申请权限 在module.json5文件下申请位置权限 "requestPermissions": [{"name": "ohos.permission.LOCATION", // 权限名称,为系统已定义的权限"reason": "$string:location_reason", // 申请权限的原因,…

windows 核心编程第五章:演示作业的使用及获取统计信息

演示作业的使用及获取统计信息 演示作业的使用及获取统计信息 文章目录 演示作业的使用及获取统计信息演示作业的使用及获取统计信息 演示作业的使用及获取统计信息 /* 演示作业的使用及获取统计信息 */#include <stdio.h> #include <Windows.h> #include <tc…

HBase原理和操作

目录 一、HBase在Zookeeper中的存储元数据信息集群状态信息 二、HBase的操作Web Console命令行操作 三、HBase中数据的保存过程 一、HBase在Zookeeper中的存储 元数据信息 HBase的元数据信息是HBase集群运行所必需的关键数据&#xff0c;它存储在Zookeeper的"/hbase&quo…

ARM32开发——(七)GD32F4串口引脚_复用功能_查询

1. GD32F4串口引脚查询 TX RX CK CTS RTS USART0 PA9,PA15,PB6 PA10,PB3,PB7 PA8 PA11 PA12 USART1 PA2,PD5 PA3,PD6 PA4,PD7 PA0,PD3 PA1,PD4 USART2 PB10,PC10,PD8 PB11,PC5,PD9 PB12,PC12,PD10 PB13,PD11 PB14,PD12 UART3 PA0,PC10 PA1,PC11 …

kafka 入门

kafka 有分区和副本的概念&#xff0c;partition 3 表示有3个分区&#xff0c;replication 2 表示有2个副本 通过 --describe --topic test命令可以知道 test这个 主题的分区和副本情况&#xff0c;途中的replicas 表示 其他副本分区的情况&#xff0c;如第一条&#xff0c;t…

【运筹学】【数据结构】【经典算法】最小生成树问题及贪心算法设计

1 知识回顾 我们已经讲过最小生成树问题的基础知识&#xff0c;我们现在想要利用贪心算法解决该问题。我们再来回顾一下最小生成树问题和贪心算法的基础知识。 最小生成树问题就是从某个图中找出总权重最小的生成树。 贪心算法是一种算法设计范式&#xff0c;每一步都选…

深度学习学习经验——全连接神经网络(FCNN)

什么是全连接神经网络&#xff1f; 全连接神经网络&#xff08;FCNN&#xff09;是最基础的神经网络结构&#xff0c;它由多个神经元组成&#xff0c;这些神经元按照层级顺序连接在一起。每一层的每个神经元都与前一层的每个神经元连接。 想象你在参加一个盛大的晚会&#xf…

Vue中的this.$emit()方法详解【父子组件传值常用】

​在Vue中&#xff0c;this.$emit()方法用于触发自定义事件。它是Vue实例的一个方法&#xff0c;可以在组件内部使用。 使用this.$emit()方法&#xff0c;你可以向父组件发送自定义事件&#xff0c;并传递数据给父组件。父组件可以通过监听这个自定义事件来执行相应的逻辑。 …

问界M7 Pro这招太狠了,直击理想L6/L7要害

文 | AUTO芯球 作者 | 雷慢 李想的理想估计要失眠了&#xff0c;为什么啊&#xff1f; 前有L6悬架薄如铁片被曝光&#xff0c;被车主们骂了个狗血淋头&#xff0c; 现在又来个问界M7 Pro版&#xff0c; 24.98万的后驱智驾版就上华为ADS主视觉智驾了&#xff0c; 两个后驱&…