LSTM网络:一种强大的时序数据建模工具

news/2024/10/28 0:16:45/

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

LSTM

(封面图由ERNIE-ViLG AI 作画大模型生成)

LSTM网络:一种强大的时序数据建模工具

在日常生活和工作中,我们经常会遇到各种类型的时序数据,如股票价格、天气数据、心电图、语音识别、自然语言处理等。这些数据具有时间依赖性,不同时间点的数据之间存在关联性。而LSTM网络是一种非常适合处理时序数据的神经网络,已经被广泛应用于各种任务中。本文将介绍LSTM网络的原理、优势和劣势,并结合代码和案例进行实践演示。

1. LSTM网络原理

LSTM(Long Short-Term Memory)网络是一种循环神经网络(Recurrent Neural Network, RNN)的变种。相比于传统的RNN,LSTM网络有着更强的长时记忆和远距离依赖处理能力,能够有效地避免梯度消失和梯度爆炸问题。

LSTM网络包括三个门控单元,分别是输入门(input gate)、遗忘门(forget gate)和输出门(output gate)。这些门控单元可以选择性地控制信息的流动,以达到记忆和遗忘的目的。除此之外,LSTM网络还有一个记忆单元(memory cell),用来存储长期的信息。

LSTM网络的计算过程可以分为以下几个步骤:

  • 输入门的计算:输入门决定了当前输入的信息在多大程度上被传递到记忆单元中。输入门的输出值为 iti_tit,计算公式如下:
    it=σ(Wixt+Uiht−1+bi)i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)it=σ(Wixt+Uiht1+bi)
    其中,xtx_txt 表示当前时刻的输入,ht−1h_{t-1}ht1 表示上一个时刻的隐藏状态,WiW_iWiUiU_iUibib_ibi 是可学习的参数,σ\sigmaσ 是sigmoid函数。

  • 遗忘门的计算:遗忘门决定了哪些历史信息需要被遗忘。遗忘门的输出值为 ftf_tft,计算公式如下:
    ft=σ(Wfxt+Ufht−1+bf)f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)ft=σ(Wfxt+Ufht1+bf)
    其中,WfW_fWfUfU_fUfbfb_fbf 是可学习的参数。

  • 记忆单元的更新:根据输入门的输出值和遗忘门的输出值,可以计算出当前时刻的记忆单元 CtC_tCt,计算公式如下:
    tanh⁡(Wcxt+Ucht−1+bc)\tanh(W_c x_t + U_c h_{t-1} + b_c)tanh(Wcxt+Ucht1+bc)
    其中,⊙\odot 表示逐元素乘积,WcW_cWcUcU_cUcbcb_cbc 是可学习的参数,tanh⁡\tanhtanh 是双曲正切函数。

  • 输出门的计算:输出门决定了当前时刻的输出值。输出门的输出值为 oto_tot,计算公式如下:
    ot=σ(Woxt+Uoht−1+bo)o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)ot=σ(Woxt+Uoht1+bo)
    其中,WoW_oWoUoU_oUobob_obo 是可学习的参数。

  • 隐藏状态的计算:根据当前时刻的记忆单元和输出门的输出值,可以计算出当前时刻的隐藏状态 hth_tht,计算公式如下:
    ht=ot⊙tanh⁡(Ct)h_t = o_t \odot \tanh(C_t)ht=ottanh(Ct)
    LSTM网络通过这些门控单元的选择性连接,实现了对时序数据的长期依赖性建模。同时,由于LSTM网络中的梯度可以通过记忆单元从一层传递到另一层,可以有效地避免梯度消失和梯度爆炸问题,提高了训练效率和模型的准确性。

2. LSTM网络的优势和劣势

  • 优势:
    (1)长期依赖性建模能力强:LSTM网络具有很好的长期依赖性建模能力,能够很好地处理时序数据中的长期依赖关系。
    (2)避免梯度消失和梯度爆炸问题:LSTM网络中的梯度可以通过记忆单元从一层传递到另一层,可以有效地避免梯度消失和梯度爆炸问题。
    (3)可适应不同长度的时序数据:LSTM网络中的记忆单元可以自适应地存储不同长度的时序数据,不需要事先指定固定长度。

  • 劣势:
    (1)计算复杂度高:LSTM网络中有多个门控单元和记忆单元,计算复杂度较高,需要更多的计算资源。
    (2)需要大量的数据训练
    :LSTM网络具有很多可调参数,需要大量的数据进行训练,否则容易出现过拟合现象。

3. 案例演示

为了更好地理解LSTM网络的应用,本文选取了一个经典的时序数据建模问题:股票价格预测。我们将使用Keras深度学习框架,使用LSTM网络对股票价格进行预测。

(1)数据预处理

首先,我们需要对股票价格数据进行预处理。我们选择了纽约证券交易所上市的Apple公司(AAPL)的历史股票价格数据,该数据包含了从1980年到2021年的日交易数据。我们将使用前70%的数据作为训练集,后30%的数据作为测试集。

在预处理数据之前,我们需要导入相关的库:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler

接下来,我们加载股票价格数据,并按照训练集和测试集的比例进行拆分:

df = pd.read_csv('AAPL.csv')
df.head()

我们可以看到,数据包含日期、开盘价、最高价、最低价、收盘价、成交量和股票调整后的收盘价。我们只需要使用调整后的收盘价作为特征进行建模。

# 只使用调整后的收盘价作为特征
data = df.filter(['Adj Close']).values# 拆分训练集和测试集
training_data_len = int(len(data) * 0.7)
train_data = data[0:training_data_len]
test_data = data[training_data_len:]

接下来,我们需要对数据进行归一化处理,使得所有特征都在0到1之间。

# 归一化处理
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_train_data = scaler.fit_transform(train_data)
scaled_test_data = scaler.transform(test_data)

(2)创建LSTM模型

接下来,我们需要创建LSTM模型。在Keras中,我们可以使用LSTM层来创建LSTM模型。首先,我们需要指定LSTM层中的参数,包括LSTM单元的数量、输入序列的长度和输出序列的长度。

from keras.models import Sequential
from keras.layers import LSTM, Dense# 指定LSTM模型参数
lstm_units = 50
input_seq_len = 60
output_seq_len = 30

接下来,我们创建LSTM模型。模型包含一个LSTM层和一个全连接层。在LSTM层中,我们使用50个LSTM单元,输入序列的长度为60,输出序列的长度为30。在全连接层中,我们使用一个神经元作为输出层。

model = Sequential()
model.add(LSTM(units=lstm_units, input_shape=(input_seq_len, 1)))
model.add(Dense(units=1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.summary()

(3)训练模型

接下来,我们需要训练模型。在训练模型之前,我们需要将训练数据划分成输入序列和输出序列。

def create_sequences(data, input_seq_len, output_seq_len):x = []y = []for i in range(len(data)-input_seq_len-output_seq_len+1):x.append(data[i:i+input_seq_len])y.append(data[i+input_seq_len:i+input_seq_len+output_seq_len])return np.array(x), np.array(y)
train_x, train_y = create_sequences(scaled_train_data, input_seq_len, output_seq_len)
test_x, test_y = create_sequences(scaled_test_data, input_seq_len, output_seq_len)

接下来,我们可以使用train_x和train_y训练模型:

history = model.fit(train_x, train_y, epochs=50, batch_size=32, validation_split=0.1, verbose=1)我们可以使用Matplotlib绘制训练和验证损失的曲线:```python
# 绘制训练和验证损失的曲线
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()

(4)模型预测

训练完成后,我们可以使用模型对测试集中的股票价格进行预测。由于我们使用了30个股票价格作为输出序列的长度,因此每次预测时,我们需要使用前60个价格作为输入序列。

def predict_future(model, data, input_seq_len, output_seq_len):predicted_data = []for i in range(len(data)-input_seq_len-output_seq_len+1):input_data = data[i:i+input_seq_len]predicted_seq = []for j in range(output_seq_len):predicted_price = model.predict(input_data.reshape((1, input_seq_len, 1)))[0][0]predicted_seq.append(predicted_price)input_data = np.append(input_data[1:], [[predicted_price]], axis=0)predicted_data.append(predicted_seq)return np.array(predicted_data)predicted_data = predict_future(model, scaled_test_data, input_seq_len, output_seq_len)
predicted_data = scaler.inverse_transform(predicted_data.reshape((-1, output_seq_len)))
test_data = scaler.inverse_transform(test_y.reshape((-1, output_seq_len)))

接下来,我们可以绘制预测结果和实际结果的图表:

# 绘制预测结果和实际结果的图表
plt.figure(figsize=(10, 6))
plt.plot(test_data, label='Actual')
plt.plot(predicted_data.flatten(), label='Predicted')
plt.legend()
plt.show()

完整的代码如下所示:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import LSTM, Dense# 加载股票价格数据
df = pd.read_csv('AAPL.csv')# 只使用调整后的收盘价作为特征
data = df.filter(['Adj Close']).values# 拆分训练集和测试集
training_data_len = int(len(data) * 0.7)
train_data = data[0:training_data_len]
test_data = data[training_data_len:]# 归一化处理
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_train_data = scaler.fit_transform(train_data)
scaled_test_data = scaler.transform(test_data)# 指定LSTM模型参数
lstm_units = 50
input_seq_len = 60
output_seq_len = 30# 创建LSTM模型
model = Sequential()
model.add(LSTM(units=lstm_units, input_shape=(input_seq_len, 1)))
model.add(Dense(units=1))
model.compile(loss='mean_squared_error', optimizer='adam')
model.summary()
# 训练模型
history = model.fit(train_x, train_y, epochs=50, batch_size=32, validation_split=0.1, verbose=1)# 绘制训练和验证损失的曲线
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()# 使用模型预测股票价格
predicted_data = predict_future(model, scaled_test_data, input_seq_len, output_seq_len)
predicted_data = scaler.inverse_transform(predicted_data.reshape((-1, output_seq_len)))
test_data = scaler.inverse_transform(test_y.reshape((-1, output_seq_len)))
# 绘制预测结果和实际结果的图表
plt.figure(figsize=(10, 6))
plt.plot(test_data, label='Actual')
plt.plot(predicted_data.flatten(), label='Predicted')
plt.legend()
plt.show()

4. 公式推导

LSTM模型中的关键部分是门控单元,它能够控制信息的流动,从而实现长期依赖关系的捕捉。门控单元由三个部分组成:遗忘门、输入门和输出门。

遗忘门用于控制前一时刻的记忆细胞中的信息是否需要被遗忘,其公式为:

ft=σ(Wf[ht−1,xt]+bf)f_t=\sigma(W_f[h_{t-1},x_t]+b_f)ft=σ(Wf[ht1,xt]+bf)

其中,ht−1h_{t-1}ht1为前一时刻的隐藏状态,xtx_txt为当前时刻的输入,WfW_fWfbfb_fbf为遗忘门的权重和偏置,σ\sigmaσ为sigmoid函数。

输入门用于控制当前时刻输入信息的权重,其公式为:

it=σ(Wi[ht−1,xt]+bi)i_t=\sigma(W_i[h_{t-1},x_t]+b_i)it=σ(Wi[ht1,xt]+bi)

其中,WiW_iWibib_ibi为输入门的权重和偏置。

记忆细胞的更新通过下面的公式实现:

Ct=ft⊙Ct−1+it⊙tanh⁡(Wc[ht−1,xt]+bc)C_t=f_t\odot C_{t-1}+i_t\odot \tanh(W_c[h_{t-1},x_t]+b_c)Ct=ftCt1+ittanh(Wc[ht1,xt]+bc)

其中,⊙\odot表示元素乘积,WcW_cWcbcb_cbc为更新记忆细胞的权重和偏置,tanh⁡\tanhtanh表示双曲正切函数。

输出门用于控制输出信息的权重,其公式为:

ot=σ(Wo[ht−1,xt]+bo)o_t=\sigma(W_o[h_{t-1},x_t]+b_o)ot=σ(Wo[ht1,xt]+bo)

hth_tht为当前时刻的隐藏状态,其计算公式为:

ht=ot⊙tanh⁡(Ct)h_t=o_t\odot \tanh(C_t)ht=ottanh(Ct)

最终的预测结果通过连接输出层实现。


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈


http://www.ppmy.cn/news/30844.html

相关文章

【预告】ORACLE Unifier v22.12 虚拟机发布

引言 离ORACLE Primavera Unifier 最新系统 v22.12已过去了3个多月,应盆友需要,也为方便大家体验,我近日将构建最新的Unifier的虚拟环境,届时将分享给大家,最终可通过VMWare vsphere (esxi) / workstation 或Oracle …

文件预览kkFileView安装及使用

1 前言网页端一般会遇到各种文件,比如:txt、doc、docx、pdf、xml、xls、xlsx、ppt、pptx、zip、png、jpg等等。有时候我们不想要把文件下载下来,而是想在线打开文件预览 ,这个时候如果每一种格式都需要我们去写代码造轮子去实现预…

类和对象及其构造方法

类和对象 现实世界的事物由什么组成? 属性 行为 类也可以包含属性和行为,所以使用类描述现实世界事物是非常合适的类和对象的关系是什么? 类是程序中的“设计图纸” 对象是基于图纸生产的具体实体什么是面向对象编程? 面向对象编…

【C语言】详解静态变量static

关键字static 在C语言中:static是用来修饰变量和函数的static主要作用为:1. 修饰局部变量-静态局部变量 2. 修饰全局变量-静态全局变量3. 修饰函数-静态函数在讲解静态变量之前,我们应该了解静态变量和其他变量的区别: 修饰局部变量 //代码1 #include &l…

【打卡-Coggle竞赛学习2023年3月】对话意图识别

学习链接: https://coggle.club/blog/30days-of-ml-202303 ## Part1 内容介绍 本月竞赛学习将以对话意图识别展开,意图识别是指分析用户的核心需求,错误的识别几乎可以确定找不到能满足用户需求的内容,导致产生非常差的用户体验…

大数据面试集锦

一、Linux 1)常用的高级命令top iotop ps -ef df -h natstat jmap -heap tar rpm 2)查看进程 查看端口号 查看磁盘使用情况 查看某个进程内存ps -ef natstat df -h jmap -heap二、shell 1、用过哪些工具awk sed …

matlab基础到实战(1)

目录概述sin函数例子四则运算实数复数逻辑运算复数运算模幅角共轭向量二维向量定义序列生成向量向量索引方式加减乘除向量间运算加减乘法除法概述 MATLAB是美国MathWorks公司出品的商业数学软件,用于数据分析、无线通信、深度学习、图像处理与计算机视觉、信号处理…

详解Java8中如何通过方法引用获取属性名/::的使用

在我们开发过程中常常有一个需求,就是要知道实体类中Getter方法对应的属性名称(Field Name),例如实体类属性到数据库字段的映射,我们常常是硬编码指定 属性名,这种硬编码有两个缺点。 1、编码效率低&#x…