深度学习--------------------长短期记忆网络(LSTM)

ops/2024/10/19 3:15:44/

目录

  • 长短期记忆网络
    • 候选记忆单元
    • 记忆单元
    • 隐状态
  • 长短期记忆网络代码从零实现
    • 初始化模型参数
    • 初始化
    • 实际模型
    • 训练
  • 简洁实现

长短期记忆网络

忘记门:将值朝0减少
输入门:决定要不要忽略掉输入数据
输出门:决定要不要使用隐状态。



在这里插入图片描述

在这里插入图片描述




候选记忆单元

在这里插入图片描述




记忆单元

记忆单元会把上一个时刻的记忆单元作为状态放进来,所以LSTM和RNN跟GRU不一样的地方是它的状态里面有两个独立的。
如果: F t F_t Ft等于0的话,就是希望不要记住 C t − 1 C_{t-1} Ct1
如果: I t I_t It是1的话,就是希望尽量的去用它,如果 I t I_t It等于0的话,就是把现在的记忆单元丢掉。

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述




隐状态

在这里插入图片描述

在这里插入图片描述


在这里插入图片描述




长短期记忆网络代码从零实现

import torch
from torch import nn
from d2l import torch as d2l# 设置批量大小为32,时间步数为35
batch_size, num_steps = 32, 35
# 使用d2l库中的load_data_time_machine函数加载时间机器数据集,
# 并设置批量大小为32,时间步数为35,将加载的数据集赋值给train_iter和vocab变量
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)



初始化模型参数

def get_lstm_params(vocab_size, num_hiddens, device):# 将词汇表大小赋值给num_inputs和num_outputsnum_inputs = num_outputs = vocab_size# 定义一个辅助函数normal,用于生成具有特定形状的正态分布随机数,并将其初始化为较小的值def normal(shape):return torch.randn(size=shape, device=device) * 0.01# 定义一个辅助函数three,用于生成三个参数:输入到隐藏状态的权重矩阵、隐藏状态到隐藏状态的权重矩阵和隐藏状态的偏置项def three():return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))# 调用three函数获取输入到隐藏状态的权重矩阵W_xi、隐藏状态到隐藏状态的权重矩阵W_hi和隐藏状态的偏置项b_iW_xi, W_hi, b_i = three()# 调用three函数获取输入到隐藏状态的权重矩阵W_xf、隐藏状态到隐藏状态的权重矩阵W_hf和隐藏状态的偏置项b_fW_xf, W_hf, b_f = three()# 调用three函数获取输入到隐藏状态的权重矩阵W_xo、隐藏状态到隐藏状态的权重矩阵W_ho和隐藏状态的偏置项b_oW_xo, W_ho, b_o = three()# 调用three函数获取输入到隐藏状态的权重矩阵W_xc、隐藏状态到隐藏状态的权重矩阵W_hc和隐藏状态的偏置项b_cW_xc, W_hc, b_c = three()# 生成隐藏状态到输出的权重矩阵W_hqW_hq = normal((num_hiddens, num_outputs))# 生成输出的偏置项b_qb_q  = torch.zeros(num_outputs, device=device)# 将所有参数组合成列表paramsparams = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]    # 变量所有参数for param in params:# 将所有参数的requires_grad属性设置为True,表示需要计算梯度param.requires_grad_(True)# 返回所有参数return params



初始化

def init_lstm_state(batch_size, num_hiddens, device):# 返回一个元组,包含两个张量:一个全零张量表示初始的隐藏状态(即:H要有个初始化),和一个全零张量表示初始的记忆细胞状态(即:C要有个初始化)。return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))



实际模型

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params# 解包状态元组state,分别赋值给隐藏状态H和记忆细胞状态C(H, C) = state# 创建一个空列表用于存储每个时间步的输出outputs = []# 对于输入序列中的每个时间步for X in inputs:# 输入门的计算:使用输入、隐藏状态和偏置项,通过线性变换和sigmoid函数计算输入门I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)# 遗忘门的计算:使用输入、隐藏状态和偏置项,通过线性变换和sigmoid函数计算遗忘门F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)# 输出门的计算:使用输入、隐藏状态和偏置项,通过线性变换和sigmoid函数计算输出门O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)# 新的记忆细胞候选值的计算:使用输入、隐藏状态和偏置项,通过线性变换和tanh函数计算新的记忆细胞候选值C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)# 更新记忆细胞状态:将旧的记忆细胞状态与遗忘门和输入门的乘积相加,再与新的记忆细胞候选值的乘积相加,得到新的记忆细胞状态C = F * C + I * C_tilda# 更新隐藏状态:将输出门和经过tanh函数处理的记忆细胞状态的乘积作为新的隐藏状态H = O * torch.tanh(C)# 输出的计算:使用新的隐藏状态和偏置项,通过线性变换得到输出Y = (H @ W_hq) + b_q# 将当前时间步的输出添加到列表中outputs.append(Y)# 将所有时间步的输出在维度0上拼接起来,作为最终的输出结果;# 返回最终的输出结果和更新后的隐藏状态和记忆细胞状态的元组return torch.cat(outputs, dim=0), (H, C)



训练

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
# 使用d2l库中的RNNModelScratch类创建一个基于LSTM的模型对象,
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

在这里插入图片描述

在这里插入图片描述




简洁实现

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
# 使用nn.LSTM创建一个LSTM层,输入特征数量为num_inputs,隐藏单元数量为num_hiddens
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
# 使用d2l库中的RNNModel类创建一个基于LSTM的模型对象,传入LSTM层和词汇表大小
model = d2l.RNNModel(lstm_layer, len(vocab))
mode = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
d2l.plt.show()

在这里插入图片描述

在这里插入图片描述


http://www.ppmy.cn/ops/119885.html

相关文章

Apache安装后无法启动的问题“不能再本地计算机启动apache”

首先安装 参考这位博主的小白下载和安装Apache的教程(保姆级) 遇到的问题 在启动的时候遇到问题 说apache不能在本地计算机启动 解决方法 1. 路径检查 首先!!! 请仔细检查你的httpd.conf文件中的Apache路径是否…

Unity开发绘画板——04.笔刷大小调节

笔刷大小调节 上面的代码中其实我们已经提供了笔刷大小的字段,即brushSize,现在只需要将该字段和界面中的Slider绑定即可,Slider值的范围我们设置为1~20 代码中只需要做如下改动: public Slider brushSizeSlider; //控制笔刷大…

javax.net.ssl.SSLHandshakeException: Chain validation failed

异常描述&#xff1a; D/OkHttp: <-- HTTP FAILED: javax.net.ssl.SSLHandshakeException: Chain validation failed com.bfmd.okhttpsample I/Main: error: Chain validation failed异常解决&#xff1a; 解决方法一&#xff1a; 解决方法很简单&#xff0c;检查一下设备…

十七、触发器

文章目录 0. 引入1. 触发器概述2. 触发器的创建2.1 触发器的创建2.2 代码举例 3. 查看、删除触发器3.1 查看触发器3.2 删除触发器 4. 触发器的优缺点4.1 优点4.2 缺点4.3 注意点 0. 引入 在实际开发中&#xff0c;我们经常会遇到这样的情况&#xff1a;有 2 个或者多个相互关联…

物联网智能项目研究

物联网&#xff08;IoT&#xff09;作为当今数字化转型的重要推动力&#xff0c;正在改变我们的生活方式和工作模式。从智能家居、智慧城市到工业自动化&#xff0c;物联网技术的应用正在实现人们对智能生活的向往。本文将探讨一个具体的物联网智能项目&#xff0c;通过实际操作…

vue2与vue3知识点

1.vue2&#xff08;optionsAPI&#xff09;选项式API 2.vue3&#xff08;composition API&#xff09;响应式API vue3 setup 中this是未定义&#xff08;undefined&#xff09;vue3中已经开始弱化this vue2通过this可以拿到vue3setup定义得值和方法 setup语法糖 ref > …

3. 轴指令(omron 机器自动化控制器)——>MC_MoveFeed

机器自动化控制器——第三章 轴指令 8 MC_MoveFeed变量▶输入变量▶输出变量▶输入输出变量 功能说明▶指令详情▶时序图▶重启运动指令▶多重启动运动指令▶异常 示例程序▶参数设定▶动作示例▶梯形图▶结构文本(ST) MC_MoveFeed 指定自外部输入的中断输入发生位置起的移动距…