使用Python和Keras库实现基于双向门控循环单元(BiGRU)模型进行深度学习序列预测的示例

embedded/2025/3/16 9:35:22/

下面是一个使用Python和Keras库实现基于双向门控循环单元(BiGRU)模型进行深度学习序列预测的示例。我们将使用一个简单的正弦波时间序列数据来演示该过程。

步骤:

  1. 数据生成:生成正弦波时间序列数据。
  2. 数据预处理:将数据划分为训练集和测试集,并将其转换为适合模型输入的格式。
  3. 模型构建:构建BiGRU模型。
  4. 模型训练:使用训练数据对模型进行训练。
  5. 模型评估:使用测试数据评估模型的性能。
  6. 预测可视化:可视化模型的预测结果。

代码实现:

python">import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Bidirectional, GRU, Dense
from sklearn.preprocessing import MinMaxScaler# 1. 数据生成
def generate_sine_wave(seq_length):time = np.linspace(0, 2 * np.pi, seq_length)sine_wave = np.sin(time)return sine_wave# 2. 数据预处理
def prepare_data(data, look_back):X, y = [], []for i in range(len(data) - look_back):X.append(data[i:(i + look_back)])y.append(data[i + look_back])X = np.array(X)y = np.array(y)# 调整输入数据的形状以适应BiGRU模型 [样本数, 时间步长, 特征数]X = np.reshape(X, (X.shape[0], X.shape[1], 1))return X, y# 生成正弦波数据
seq_length = 1000
sine_wave = generate_sine_wave(seq_length)# 数据归一化
scaler = MinMaxScaler(feature_range=(0, 1))
sine_wave = scaler.fit_transform(sine_wave.reshape(-1, 1)).flatten()# 划分训练集和测试集
train_size = int(len(sine_wave) * 0.8)
train_data = sine_wave[:train_size]
test_data = sine_wave[train_size:]# 准备训练数据和测试数据
look_back = 10
X_train, y_train = prepare_data(train_data, look_back)
X_test, y_test = prepare_data(test_data, look_back)# 3. 模型构建
model = Sequential()
model.add(Bidirectional(GRU(50, activation='relu', input_shape=(look_back, 1))))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')# 4. 模型训练
model.fit(X_train, y_train, epochs=50, batch_size=32, verbose=1)# 5. 模型评估
test_loss = model.evaluate(X_test, y_test, verbose=0)
print(f'Test Loss: {test_loss}')# 6. 预测可视化
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)# 反归一化
train_predict = scaler.inverse_transform(train_predict)
y_train = scaler.inverse_transform(y_train.reshape(-1, 1))
test_predict = scaler.inverse_transform(test_predict)
y_test = scaler.inverse_transform(y_test.reshape(-1, 1))# 绘制训练集预测结果
plt.figure(figsize=(12, 6))
plt.plot(y_train, label='True Train Values')
plt.plot(train_predict, label='Predicted Train Values')
plt.title('Train Data Prediction')
plt.xlabel('Time Step')
plt.ylabel('Value')
plt.legend()
plt.show()# 绘制测试集预测结果
plt.figure(figsize=(12, 6))
plt.plot(y_test, label='True Test Values')
plt.plot(test_predict, label='Predicted Test Values')
plt.title('Test Data Prediction')
plt.xlabel('Time Step')
plt.ylabel('Value')
plt.legend()
plt.show()

代码解释:

  1. 数据生成generate_sine_wave 函数生成一个正弦波时间序列数据。
  2. 数据预处理prepare_data 函数将时间序列数据转换为适合BiGRU模型输入的格式。同时,使用 MinMaxScaler 对数据进行归一化处理。
  3. 模型构建:使用Keras的 Sequential 模型构建一个简单的BiGRU模型,包含一个双向GRU层和一个全连接层。
  4. 模型训练:使用 fit 方法对模型进行训练,设置训练轮数为50,批次大小为32。
  5. 模型评估:使用 evaluate 方法评估模型在测试集上的性能。
  6. 预测可视化:使用 matplotlib 库绘制训练集和测试集的预测结果。

通过以上步骤,你可以使用BiGRU模型进行序列预测。


http://www.ppmy.cn/embedded/173027.html

相关文章

MAC地址IP地址如何转换?

0. 运维干货分享 软考系统架构设计师三科备考经验附学习资料CKA认证学习资料分享信息安全管理体系(ISMS)制度模板分享免费文档翻译工具(支持word、pdf、ppt、excel)PuTTY中文版安装包MobaXterm中文版安装包pinginfoview网络诊断工具中文版 在计算机网络…

1.排序算法(学习自用)

1.冒泡排序 算法步骤 相邻的元素之间对比,每次早出最大值或最小值放到最后或前面,所以形象的称为冒泡。 特点 n个数排序则进行n轮,每轮比较n-i次。所以时间复杂度为O(n^2),空间复杂度为O(1),该排序算法稳定。 代码…

新手村:统计量均值、中位数、标准差、四分位数

新手村:统计量均值、中位数、标准差、四分位数 统计量定义与讲解 统计量定义计算公式示例说明均值数据集中的所有数值之和除以数值的个数。 Mean ∑ i 1 n x i n \text{Mean} \frac{\sum_{i1}^{n} x_i}{n} Meann∑i1n​xi​​对于数据集 [1, 2, 3, 4, 5]&#x…

【17-3】Twitter评论情绪分类实战

139-Twitter评论情绪基础RNN模型分类 143-LSTM文本分类模型 【参考文档】17-3Twitter评论情绪分类.ipynb 【导出代码】 # %% [markdown] # # 139-Twitter评论情绪分类# %% [markdown] # ## 数据读取处理# %% import torch import torchtext import torch.nn as nn import t…

ARM64 架构地址空间分配深度解析

一、寻址空间选择的技术逻辑(基于 ARMv8 架构) 地址空间截断的工程实现(LPAE 技术) 在计算架构设计中,ARM64架构选择使用48位/52位虚拟地址空间而非完整的64位寻址,这一决策体现了硬件设计者在性能、功耗…

【A2DP】深入解读A2DP中通用访问配置文件(GAP)的互操作性要求

目录 一、模式支持要求 1.1 发现模式 1.2 连接模式 1.3 绑定模式 1.4 模式间依赖关系总结 1.5 注意事项 1.6 协议设计深层逻辑 二、安全机制(Security Aspects) 三、空闲模式操作(Idle Mode Procedures) 3.1 支持要求 …

Python 逆向工程:2025 年能破解什么?

有没有想过在复杂的软件上扭转局面?到 2025 年,Python 逆向工程不仅仅是黑客的游戏,它是开发人员、安全专业人员和好奇心强的人解开编译代码背后秘密的强大方法。无论您是在剖析恶意软件、分析 Python 应用程序的工作原理,还是学习…

多线程到底重不重要?

我们先说一下为什么要讲多线程和高并发? 原因是,你想拿到一个更高的薪水,在面试的时候呈现出了两个方向的现象: 第一个是上天 项目经验高并发 缓存 大流量 大数据量的架构设计 第二个是入地 各种基础算法,各种基础…