边写代码边学习之LSTM

news/2025/2/22 22:06:27/

1.  什么是LSTM

长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。
 

在这里插入图片描述

 

 

2. 实验代码

2.1. 搭建一个只有一层RNN和Dense网络的模型。

2.2. 验证LSTM里的逻辑

 假设我的输入数据是x = [1,0], 

kernel = [[[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],

              [1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],]]

recurrent_kernel = [[1, 0, 0, 1, 2,1,0,1,2,0,1,0],

                              [1, 1, 0, 0, 2,1,0,1,2,2,0,0],

                              [1, 0, 1, 2, 0,1,0,1,1,0,1,0]]

biase = [3, 1, 0, 1, 1,0,0,1,0,2,0.0,0]

通过下面手算,h的结果是[0, 4,1], c 的结果是[0,4,1].  注意无激活函数。

代码验证上面的结果


def change_weight():# Create a simple Dense layerlstm_layer = LSTM(units=3, input_shape=(3, 2), activation=None, recurrent_activation=None, return_sequences=True,return_state= True)# Simulate input data (batch size of 1 for demonstration)input_data = np.array([[[1.0, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]])# Pass the input data through the layer to initialize the weights and biaseslstm_layer(input_data)kernel, recurrent_kernel, biases = lstm_layer.get_weights()# Print the initial weights and biasesprint("recurrent_kernel:", recurrent_kernel, recurrent_kernel.shape ) # (3,3)print('kernal:',kernel, kernel.shape) #(2,3)print('biase: ',biases , biases.shape) # (3)kernel = np.array([[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],[1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],])recurrent_kernel = np.array([[1, 0, 0, 1, 2,1,0,1,2,0,1,0],[1, 1, 0, 0, 2,1,0,1,2,2,0,0],[1, 0, 1, 2, 0,1,0,1,1,0,1,0]])biases = np.array([3, 1, 0, 1, 1,0,0,1,0,2,0.0,0])lstm_layer.set_weights([kernel, recurrent_kernel, biases])print(lstm_layer.get_weights())# test_data = np.array([#     [[1.0, 3], [1, 1], [2, 3]]# ])test_data = np.array([[[1,0.0]]])output, memory_state, carry_state  = lstm_layer(test_data)print(output)print(memory_state)print(carry_state)
if __name__ == '__main__':change_weight()

执行结果:

recurrent_kernel: [[-0.36744034 -0.11181469 -0.10642298  0.5450207  -0.30208975  0.54054320.09643812 -0.14983998  0.1859854   0.2336958  -0.16187981  0.11621032][ 0.07727922 -0.226477    0.1491096  -0.03933501  0.31236103 -0.129630920.10522162 -0.4815724  -0.2093935   0.34740582 -0.60979587 -0.15877807][ 0.15371156  0.01244636 -0.09840634 -0.32093546  0.06523462  0.189349320.38859126 -0.3261706  -0.05138849  0.42713478  0.49390993  0.37013963]] (3, 12)
kernal: [[-0.47606698 -0.43589187 -0.5371355  -0.07337284  0.30526626 -0.18241835-0.03675252  0.2873094   0.33218485  0.24838251  0.17765659  0.4312396 ][ 0.4007727   0.41280174  0.40750778 -0.6245315   0.6382301   0.428892250.11961156 -0.6021105  -0.43556038  0.39798307  0.6390712   0.16719025]] (2, 12)
biase:  [0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0.] (12,)
[array([[2., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0.],[1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.]], dtype=float32), array([[1., 0., 0., 1., 2., 1., 0., 1., 2., 0., 1., 0.],[1., 1., 0., 0., 2., 1., 0., 1., 2., 2., 0., 0.],[1., 0., 1., 2., 0., 1., 0., 1., 1., 0., 1., 0.]], dtype=float32), array([3., 1., 0., 1., 1., 0., 0., 1., 0., 2., 0., 0.], dtype=float32)]
tf.Tensor([[[0. 4. 0.]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor([[0. 4. 0.]], shape=(1, 3), dtype=float32)
tf.Tensor([[0. 4. 1.]], shape=(1, 3), dtype=float32)

可以看出h=[0,4,0], c=[0,4,1]


http://www.ppmy.cn/news/1009646.html

相关文章

如何在 Android 上恢复已删除的视频|快速找回丢失的记忆

想知道是否有任何成功的方法可以从 Android 手机中检索已删除的视频?好吧,本指南将向您展示分步说明,让您轻松从手机中找回丢失的视频文件! 您是否不小心从 Android 智能手机中删除了珍贵的生日视频?难道是无处可寻吗…

Julia 文件(File)读写

Julia 提供了一些基本的函数来处理文件: open() - 打开文件read() - 读取文件内容close() - 关闭文件 从文件读取或者写入数据需要使用文件句柄。 文件句柄其实就是一个指针,指针就是指向文件中的某个位置。 从一个文件读取数据,应用程序…

K8S系列文章之 内外网如何互相访问

K8S中网络这块主要考虑 如何访问外部网络以及外部如何访问内部网络 访问外网服务的两种方式 需求 k8s集群内的pod需要访问mysql,由于mysql的性质,不适合部署在k8s集群内,故k8s集群内的应用需要链接mysql时,需要配置链接外网的mysql,本次测试…

Vue_02:详细语法以及代码示例 + 知识点练习 + 综合案例(第二期)

2023年8月4日15:25:01 Vue_02_note 在Vue中,非相应式数据,直接往实例上面挂载就可以了。 01_Vue 指令修饰符 什么是指令修饰符呢? 答: 通过 " . " 指明一些指令后缀,不同 后缀 封装了不同的处理操作 —…

设计模式之五:单例模式

有些对象只需要有一个,比如线程池、缓存和注册表等。 对比全局变量,其需要在程序开始就创建好对象,如果这个对象比较耗资源,而在后面的执行过程中又一直没有用到,就造成了浪费。 class Singleton {private:static Si…

Go如何构建高效API接口| 青训营

Go语言作为一个高效的静态类型语言,在构建API服务时也表现出了很大的优势。本文将介绍如何使用Go语言构建高效的API服务,帮助开发者更好地应对日益增长的API需求。 一、选择适合的框架 首先,选择适合的框架是构建高效API服务的重要一步。在…

uni-app 开发踩坑

1、自定义组件名称要大写开头 2、[“usingComponents”][“uni-api”] not found 关闭小程序调试,重新打开调试 封装 uni.request https://blog.csdn.net/mostrichman/article/details/130409476 https://blog.csdn.net/a1769815985/article/details/126018110 …

【性能测试】关于系统用户数,并发用户数,在线用户数,吞吐量

目录 1、概念 系统用户数 在线用户数 并发用户数 计算公式 2、吞吐量 资料获取方法 1、概念 系统用户数 狭义上来说,可以理解为系统注册用户数;广义上来说,可以理解为所有访问过系统的用户数 在线用户数 狭义上来说,可以…