82.长短期记忆网络(LSTM)以及代码实现

news/2024/12/29 15:37:40/

1. 长短期记忆网络

  • 忘记门:将值朝0减少
  • 输入门:决定不是忽略掉输入数据
  • 输出门:决定是不是使用隐状态

2. 门

在这里插入图片描述

3. 候选记忆单元

在这里插入图片描述

4. 记忆单元

在这里插入图片描述

5. 隐状态

在这里插入图片描述

6. 总结

在这里插入图片描述

7. 从零实现的代码

我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

7.1 初始化模型参数

接下来,我们需要定义和初始化模型参数。 如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差 0.01 的高斯分布初始化权重,并将偏置项设为 0 。

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

7.2 定义模型

初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

# C和H都要初始化
def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元 𝐂𝑡 不直接参与输出计算

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

7.3 训练和预测

让我们通过实例化rnn_scratch中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在gru中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果:

在这里插入图片描述

8. 简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

运行结果:

在这里插入图片描述

实际情况下,LSTM和GRU用哪个都可以,性能差不多。

长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。 多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。 然而,由于序列的长距离依赖性,训练长短期记忆网络 和其他序列模型(例如门控循环单元)的成本是相当高的。 在后面的内容中,我们将讲述更高级的替代模型,如Transformer

9. Q&A


http://www.ppmy.cn/news/13132.html

相关文章

2023年,PMP认证考试的心得分享

对于刚开始要准备参加PMP考试的人,大多应该都是不知道怎么去考试复习好的。PMP认证考试虽是美国的考试,但其实这跟国内其它的考试复习也差不多,没有什么很特别之处,只是多了一个中英互译,再就是学习的内容不一样&#…

【微电网】基于改进粒子群算法的微电网优化调度(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

【面试题】2023年前端最新面试题-web安全

原文见语雀:(https://www.yuque.com/deepstates/interview/tluabi) ● 网络前端安全 ○ xss ○ csrf ○ 点击挟持攻击 ○ url跳转漏洞 ○ sql注入攻击 ○ os命令注入攻击 ○ 海量接口请求 ● 前端兜底安全 ○ 兜底容灾 ⭐️⭐️⭐️ 相关知…

【进击的算法】基础算法——回溯算法

🍿本节主题:回溯算法 🎈更多算法:深入聊聊KMP算法 💕我的主页:蓝色学者的个人主页 文章目录一、前言二、概念三、例题1.题目:全排列2.解题思路回溯算法的本质问题1:问题2&#xff1a…

记2022年秋招经历

自我介绍求职体验求职心得 一、自我介绍 学历普通本科,专业是网络工程,在校期间学习主要的是计算机体系方面的知识,根据课程,自学过前端、后端等内容。包括前端三板斧(htmlcssjs)、常用的前端框架(bootstarp/Vue等)&am…

PostgreSQL 复制表的 5 种方式

PostgreSQL 提供了多种不同的复制表的方法,它们的差异在于是否需要复制表结构或者数据。 CREATE TABLE AS SELECT 语句 CREATE TABLE AS SELECT 语句可以用于复制表结构和数据,但是不会复制索引。 我们可以使用以下语句基于 employee 复制一个新表 em…

1个 30多年程序员的生涯经验总结

有人说:一个人从1岁活到80岁很平凡,但如果从80岁倒着活,那么一半以上的人都可能不凡。 生活没有捷径,我们踩过的坑都成为了生活的经验,这些经验越早知道,你要走的弯路就会越少。 在我30多年的程序员生涯里…

电脑开机出现绿屏错误无法启动怎么办?

电脑开机出现绿屏错误无法启动怎么办?有用户电脑开机的时候,突然出现了屏幕变成绿色的情况,而且上面有很多的错误代码。然后卡在页面上一直无法进入到桌面,重启电脑后依然无效。那么如何去解决这个问题呢?来看看具体的…