深度学习学习经验——长短期记忆网络(LSTM)

embedded/2024/10/18 9:26:00/

长短期记忆网络(LSTM)

长短期记忆网络(LSTM,Long Short-Term Memory)是一种特殊的循环神经网络(RNN),专为解决 RNN 中长期依赖问题而设计。LSTM 引入了三个门和一个细胞状态(cell state),以便更好地控制信息的流动,确保网络能够记住长期的依赖关系。我们将通过一个逐步深入的案例来讲解 LSTM 的内部结构和工作机制,并使用 PyTorch 实现一个 LSTM 模型。

1. 问题描述

与我上一篇讲解RNN的文章相似,假设我们仍然要预测未来的天气情况,但是这次数据包含更多的噪声,且我们希望模型能够更好地“记住”一段时间内的趋势信息。这时,LSTM 比普通 RNN 更适合这种任务,因为它能够通过门控机制更精确地控制信息流动。

2. LSTM 的基本原理

LSTM 通过引入 输入门(Input Gate)、遗忘门(Forget Gate)、输出门(Output Gate)和 细胞状态(Cell State)来管理信息的记忆和遗忘。

2.1 细胞状态(Cell State)

细胞状态是贯穿整个 LSTM 的主线,类似于一个“传送带”,它能够允许信息在序列中几乎不受干扰地传递下去。LSTM 通过少量的线性相互作用,轻松地让信息在其上流动,只有少数的部分会被门结构所改变。

2.2 遗忘门(Forget Gate)

遗忘门决定了细胞状态中哪些信息需要丢弃。这个门读取当前的输入 ( x t (x_t (xt) 和上一时刻的隐藏状态 ( h t − 1 (h_{t-1} (ht1),并输出一个介于 0 和 1 之间的值,其中 0 代表完全忘记,1 代表完全保留。

f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

forget_gate = nn.Sigmoid()
f_t = forget_gate(torch.matmul(W_f, torch.cat((h_{t-1}, x_t), dim=1)) + b_f)
2.3 输入门(Input Gate)

输入门控制哪些新的信息将被写入细胞状态。这个过程分为两步:首先,输入门生成一个控制写入的信号,然后通过一个 tanh 层创建一个新的候选细胞状态。

i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)

C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC[ht1,xt]+bC)

input_gate = nn.Sigmoid()
i_t = input_gate(torch.matmul(W_i, torch.cat((h_{t-1}, x_t), dim=1)) + b_i)candidate_layer = nn.Tanh()
C_tilda = candidate_layer(torch.matmul(W_C, torch.cat((h_{t-1}, x_t), dim=1)) + b_C)
2.4 更新细胞状态(Cell State)

在更新细胞状态时,我们结合了遗忘门的结果 ( f t (f_t (ft) 和输入门的结果 ( i t (i_t (it) 以及候选细胞状态 ( C ~ t (\tilde{C}_t (C~t):

C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t * C_{t-1} + i_t * \tilde{C}_t Ct=ftCt1+itC~t

C_t = f_t * C_{t-1} + i_t * C_tilda
2.5 输出门(Output Gate)

输出门决定了当前的隐藏状态 ( h t (h_t (ht) 是什么,同时输出门还控制了有多少细胞状态信息能够传递到下一层。首先,输出门生成一个信号,然后结合当前的细胞状态生成新的隐藏状态。

o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)

h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t * \tanh(C_t) ht=ottanh(Ct)

output_gate = nn.Sigmoid()
o_t = output_gate(torch.matmul(W_o, torch.cat((h_{t-1}, x_t), dim=1)) + b_o)h_t = o_t * torch.tanh(C_t)

3. 使用 PyTorch 实现 LSTM

现在我们使用 PyTorch 实现一个 LSTM 模型,并用它来预测天气。

3.1 导入必要的库
import torch
import torch.nn as nn
import numpy as np
3.2 定义 LSTM 模型
class WeatherLSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(WeatherLSTM, self).__init__()self.hidden_size = hidden_sizeself.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h_0 = torch.zeros(1, x.size(0), self.hidden_size)  # 初始化隐藏状态c_0 = torch.zeros(1, x.size(0), self.hidden_size)  # 初始化细胞状态out, (h_n, c_n) = self.lstm(x, (h_0, c_0))  # 计算所有时间步长的输出out = self.fc(out[:, -1, :])  # 取最后一个时间步长的输出并通过线性层return out

在这个模型中,LSTM 层代替了 RNN 层,它自动处理了前面介绍的遗忘门、输入门、输出门和细胞状态的更新。

3.3 准备数据
# 生成示例数据
data = np.array([[30, 31, 32, 33, 34],[32, 33, 34, 35, 36],[35, 36, 37, 38, 39]],dtype=np.float32)labels = np.array([35, 37, 40], dtype=np.float32)  # 对应的目标值# 转换为 PyTorch 张量
data = torch.from_numpy(data).unsqueeze(-1)  # 添加特征维度
labels = torch.from_numpy(labels).unsqueeze(-1)
3.4 训练模型
# 定义超参数
input_size = 1
hidden_size = 10
output_size = 1
num_epochs = 100
learning_rate = 0.01# 实例化模型
model = WeatherLSTM(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):model.train()outputs = model(data)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
3.5 使用模型进行预测
model.eval()
with torch.no_grad():test_input = torch.tensor([[36, 37, 38, 39, 40]], dtype=torch.float32).unsqueeze(-1)predicted_temperature = model(test_input)print(f"预测的温度: {predicted_temperature.item():.2f}")

4. 完整代码

import torch
import torch.nn as nn
import numpy as npclass WeatherLSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(WeatherLSTM, self).__init__()self.hidden_size = hidden_sizeself.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h_0 = torch.zeros(1, x.size(0), self.hidden_size)c_0 = torch.zeros(1, x.size(0), self.hidden_size)out, (h_n, c_n) = self.lstm(x, (h_0, c_0))out = self.fc(out[:, -1, :])return out# 生成示例数据
data = np.array([[30, 31, 32, 33, 34],[32, 33, 34, 35, 36],[35, 36, 37, 38, 39]],dtype=np.float32)labels = np.array([35, 37, 40], dtype=np.float32)data = torch.from_numpy(data).unsqueeze(-1)
labels = torch.from_numpy(labels).unsqueeze(-1)# 定义超参数
input_size = 1
hidden_size = 10
output_size = 1
num_epochs = 100
learning_rate = 0.01# 实例化模型
model = WeatherLSTM(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):model.train()outputs = model(data)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 进行预测
model.eval()
with torch.no_grad():test_input = torch.tensor([[36, 37, 38, 39, 40]], dtype=torch.float32).unsqueeze(-1)predicted_temperature = model(test_input)print(f"预测的温度: {predicted_temperature.item():.2f}")

5. 总结

LSTM 通过引入遗忘门、输入门、输出门和细胞状态,能够有效解决 RNN 中的长期依赖问题,使得模型可以更好地在序列数据中保留重要信息并进行预测。


http://www.ppmy.cn/embedded/100906.html

相关文章

Linux云计算 |【第二阶段】SECURITY-DAY3

主要内容: Prometheus监控服务器、Prometheus被监控端、Grafana监控可视化 补充:Zabbix监控软件不自带LNMP和DB数据库,需要自行手动安装配置;Prometheus监控软件自带WEB页面和DB数据库;Prometheus数据库为时序数据库&…

机器人走路问题优化解法

public class Test53 {//假设有N个位置,记为1-N,N大于或等于2//开始机器人在M位置上(M为1-N中的一个)//如果机器人来到1位置,那么下一步只能向右来到2位置//如果机器人来到N位置,那么下一步只能向左来到N-1…

Vue小玩意儿:vue3+express.js实现大文件分片上传

vue3: <template><div><h1>大文件分片上传</h1><input type"file" change"onFileChange"/><div v-if"progress > 0">上传进度: {{ progress }}%</div></div> </template><script …

浅谈Kafka(三)

浅谈Kafka&#xff08;三&#xff09; 文章目录 浅谈Kafka&#xff08;三&#xff09;Kafka目录介绍基础操作JMX接口消费者是否能够消费指定分区的消息生产者是否发送消息到leader创建主题时如何把分区放到不同broker中Kafka新建的分区在哪个目录创建Kafka java示例 Kafka目录介…

代码随想录算法训练营第十一天|150. 逆波兰表达式求值 、239. 滑动窗口最大值、347.前 K 个高频元素

Leetcode150. 逆波兰表达式求值 题目链接&#xff1a;150. 逆波兰表达式求值 C&#xff1a; class Solution { public:int evalRPN(vector<string>& tokens) {stack<long long> st; for (int i 0; i < tokens.size(); i) {if (tokens[i] "" …

解密网络安全:初学者指南

密码学是网络安全的基石&#xff0c;它不仅确保数据的机密性&#xff0c;还能保护数据的完整性和不可否认性。本文将带领你了解密码学的基本概念以及它在保护数据机密性中的应用。 什么是密码学&#xff1f; 当我们通过计算机网络传输数据时&#xff0c;如果无法防止他人窃听…

dubbo:dubbo+zookeeper整合nginx实现网关(四)

文章目录 0. 引言1. nginx简介2. 集成nginx2.1 负载均衡实现 3. 源码4. 总结 0. 引言 我们之前讲解过dubbozookeeper实现服务调用和注册中心&#xff0c;但是还缺乏一个统一的入口&#xff0c;即网关服务。dubbozookeeper的模式更加适合的网关组件为nginx&#xff0c;所以今天…

【Kubernetes】K8s中Container(容器)、Pod(小组)和node(节点)概念讲解

Kubernetes学习之路 第一章 Kubernetes学习入门之Container(容器)、Pod(小组)和node(节点)概念 文章目录 Kubernetes学习之路前言一、Container&#xff08;容器&#xff09;二、Pod&#xff08;小组&#xff09;1.单容器 Pod2.多容器 Pod 三、Container&#xff08;容器&…