NNDL 作业11 LSTM

news/2024/12/26 12:22:04/

习题6-4  推导LSTM网络中参数的梯度, 并分析其避免梯度消失的效果

先来推个实例:

看式子中间,上半部分并未有连乘项,而下半部分有C_tC_{t-1}的连乘项,从这可以看出,LSTM能缓解梯度消失,梯度爆炸只是不易发生。

下面咱们来求一下:\frac{\partial C_{t}}{\partial C_{t-1}}

\frac{\partial C_{t}}{\partial C^{t-1}}=\frac{\partial C_{t}}{\partial F_{t}}\frac{\partial F_{t}}{\partial H_{t-1}}\frac{\partial H_{t-1}}{\partial C_{t-1}}+\frac{\partial C_t}{\partial I_t}\frac{\partial I_t}{\partial H_{t-1}}\frac{\partial H_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial\tilde{C}_{t}}\:\frac{\partial\tilde{C}_{t}}{\partial H_{t-1}}\frac{\partial H_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial C_{t-1}}

展开得:

\begin{aligned}&\frac{\partial C_t}{\partial C^{t-1}}=C_{t-1}\sigma^{\prime}(\cdot)U_{f}*O_{t-1}tanh^{\prime}(C_{t-1})+\widetilde{C}_t\sigma^{\prime}U_{i}*O_{t-1}tanh^{\prime}(C_{t-1})+\\&I_ttanh^{\prime}(\cdot)U_{c}*O_{t-1}tanh^{\prime}(C_{t-1})+F_t\end{aligned}

通过调节U_{f}U_{i}U_{h}来使\frac{\partial C_{t}}{\partial C_{t-1}}接近于1,从而防止梯度消失太快。

此问题我是看的视频学习的,参考链接:【【重温经典】大白话讲解LSTM长短期记忆网络  如何缓解梯度消失,手把手公式推导反向传播】https://www.bilibili.com/video/BV1qM4y1M7Nv?p=5&vd_source=d58e25af805a85358e5bc9060257ecdd

习题6-3P  编程实现下图LSTM运行过程

同学提出,未发现h_{t-1}输入。可以适当改动例题,增加该输入。

实现LSTM算子,可参考实验教材代码。

1. 使用Numpy实现LSTM算子

import numpy as np#定义激活函数
def sigmoid(x):return 1/(1+np.exp(-x))#权重
input_weight=np.array([1,0,0,0])
inputgate_weight=np.array([0,100,0,-10])
forgetgate_weight=np.array([0,100,0,10])
outputgate_weight=np.array([0,0,100,-10])#输入
input=np.array([[1,0,0,1],[3,1,0,1],[2,0,0,1],[4,1,0,1],[2,0,0,1],[1,0,1,1],[3,-1,0,1],[6,1,0,1],[1,0,1,1]])y=[]   #输出
c_t=0  #内部状态for x in input:g_t=np.matmul(input_weight,x) #候选状态i_t=np.round(sigmoid(np.matmul(inputgate_weight,x)))  #输入门after_inputgate=g_t*i_t       #候选状态经过输入门f_t=np.round(sigmoid(np.matmul(forgetgate_weight,x))) #遗忘门after_forgetgate=f_t*c_t      #内部状态经过遗忘门c_t=np.add(after_inputgate,after_forgetgate) #新的内部状态o_t=np.round(sigmoid(np.matmul(outputgate_weight,x))) #输出门after_outputgate=o_t*c_t     #新的内部状态经过输出门y.append(after_outputgate)   #输出print('输出:',y)

运行结果:

2. 使用nn.LSTMCell实现

import numpy as np
import torch
import torch.nn as nn#实例化
input_size=4
hidden_size=1
cell=nn.LSTMCell(input_size=input_size,hidden_size=hidden_size)
#修改模型参数 weight_ih.shape=(4*hidden_size, input_size),weight_hh.shape=(4*hidden_size, hidden_size),
#weight_ih、weight_hh分别为输入x、隐层h分别与输入门、遗忘门、候选、输出门的权重
cell.weight_ih.data=torch.tensor([[0,100,0,-10],[0,100,0,10],[1,0,0,0],[0,0,100,-10]],dtype=torch.float32)
cell.weight_hh.data=torch.zeros(4,1)
print('cell.weight_ih.shape:',cell.weight_ih.shape)
print('cell.weight_hh.shape',cell.weight_hh.shape)
#初始化h_0,c_0
h_t=torch.zeros(1,1)
c_t=torch.zeros(1,1)
#模型输入input_0.shape=(batch,seq_len,input_size)
input_0=torch.tensor([[[1,0,0,1],[3,1,0,1],[2,0,0,1],[4,1,0,1],[2,0,0,1],[1,0,1,1],[3,-1,0,1],[6,1,0,1],[1,0,1,1]]],dtype=torch.float32)
#交换前两维顺序,方便遍历input.shape=(seq_len,batch,input_size)
input=torch.transpose(input_0,1,0)
print('input.shape:',input.shape)
output=[]
#调用
for x in input:h_t,c_t=cell(x,(h_t,c_t))output.append(np.around(h_t.item(), decimals=3))#保留3位小数
print('output:',output)

运行结果:

3. 使用nn.LSTM实现

import numpy as np
import torch.nn# 设置参数
input_size = 4
hidden_size = 1
# 模型实例化
Lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
# 权重
Lstm.weight_ih_l0.data = torch.tensor([[0, 100, 0, -10], [0, 100, 0, 10], [1, 0, 0, 0], [0, 0, 100, -10]],dtype=torch.float32)
Lstm.weight_hh_l0.data = torch.zeros(4, 1)
# 初始化内部状态
h_t = torch.zeros(1, 1, 1)
c_t = torch.zeros(1, 1, 1)
# 输入的数据[batch_size,seq_len,input_size]
input = torch.tensor([[[1, 0, 0, 1], [3, 1, 0, 1], [2, 0, 0, 1], [4, 1, 0, 1], [2, 0, 0, 1], [1, 0, 1, 1],[3, -1, 0, 1], [6, 1, 0, 1], [1, 0, 1, 1]]], dtype=torch.float32)
y, (h_t, c_t) = Lstm(input, (h_t, c_t))
y = torch.round(y * 1000) / 1000
print(f"输出:{y}")

输出结果:

REF:

李宏毅机器学习笔记:RNN循环神经网络_李宏毅机器学习课程笔记-CSDN博客

RNN与LSTM详解

NNDL 作业十一 LSTM-CSDN博客

【【重温经典】大白话讲解LSTM长短期记忆网络  如何缓解梯度消失,手把手公式推导反向传播】https://www.bilibili.com/video/BV1qM4y1M7Nv?p=5&vd_source=d58e25af805a85358e5bc9060257ecdd


http://www.ppmy.cn/news/1558265.html

相关文章

#渗透测试#漏洞挖掘#红蓝攻防#护网#sql注入介绍11基于XML的SQL注入(XML-Based SQL Injection)

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停…

如何确保数据大屏的交互设计符合用户需求?(附实践资料下载)

确保数据大屏的交互设计符合用户需求是一个多步骤的过程,涉及到用户研究、设计原则、原型测试和持续迭代。以下是一些关键步骤和策略: 用户研究: 目标用户识别:明确大屏的目标用户群体,包括他们的背景、角色和需求。用…

C 进阶 — 程序环境和预处理

C 进阶 — 程序环境和预处理 主要内容 程序的编译和执行环境 C 程序编译和链接 预定义符号 预处理指令 #define 预处理指令 #include 预处理指令 #undef 预处理操作符 # 和 ## 宏和函数对比 命令行定义 条件编译 一 程序的编译和执行环境 ANSI C 存在两个不同环境…

Docker怎么关闭容器开机自启,批量好几个容器一起操作?

环境: WSL2 docker v25 问题描述: Docker怎么关闭容器开机自启,批量好几个容器一起操作? 解决方案: 在 Docker 中,您可以使用多种方法来关闭容器并配置它们是否在系统启动时自动启动。以下是具体步骤和…

第四节、电机定角度转动【51单片机-L298N-步进电机教程】

摘要:本节介绍电机转动角度计算步骤,从而控制步进电机转角 一、 计算过程 1.1 L28N每控制步进电机转动一步,根据程序拍数设置情况,计算步进电机步距角度step_x s t e p x s t e p X … … ① step_{x} \frac{step}{X} ……① s…

Zettlr(科研笔记) v3.4.1 中文版

Zettlr是款适合写作者和研究人员使用的Markdown编辑器,免费开源,功能简洁,具备Markdown所有基本功能,内置各种运算符,还可以调用计数器,可以完美替代Word和收费的文字处理器。 软件特点 从应用程序中直接管…

ROS1入门教程6:复杂行为处理

一、新建项目 # 创建工作空间 mkdir -p demo6/src && cd demo6# 创建功能包 catkin_create_pkg demo roscpp rosmsg actionlib_msgs message_generation tf二、创建行为 # 创建行为文件夹 mkdir action && cd action# 创建行为文件 vim Move.action# 定义行为…

论文阅读--Variational quantum algorithms

文献类型:期刊论文 作者:M. Cerezo(Los Alamos National Laboratory) 年份:2021 期刊:Nature 影响因子:44.8 摘要:由于计算成本极高,模拟复杂量子系统或解决大规模线性代…