一个超级简单的清晰的LSTM模型的例子

news/2025/1/17 11:13:47/

废话不多说,把代码贴上去,就可以运行。然后看注释,自己慢慢品,细细品。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt# 1. 生成时间序列数据,这里使用正弦函数模拟
def generate_time_series():time_steps = np.linspace(0, 10 * np.pi, 500, dtype=np.float32)data = np.sin(time_steps)data = np.expand_dims(data, axis=-1)return data# 2. 划分训练集和测试集
def prepare_data(data):train_data = data[:400]test_data = data[400:]return train_data, test_data# 3. 创建数据集
def create_dataset(data, time_steps):Xs, ys = [], []for i in range(len(data) - time_steps):v = data[i:(i + time_steps)]Xs.append(v)ys.append(data[i + time_steps])return np.array(Xs), np.array(ys)# 4. 定义 LSTM 模型
def build_model(input_shape):model = tf.keras.Sequential([tf.keras.layers.LSTM(units=50, return_sequences=True, input_shape=input_shape),tf.keras.layers.LSTM(units=50),tf.keras.layers.Dense(units=1)])model.compile(optimizer='adam', loss='mse')return model# 5. 训练模型
def train_model(model, X_train, y_train, epochs=20):history = model.fit(X_train, y_train,epochs=epochs,batch_size=32,validation_split=0.1,shuffle=False)return history# 6. 预测和可视化结果
def predict_and_visualize(model, X_train, y_train, X_test, y_test):train_predict = model.predict(X_train)test_predict = model.predict(X_test)plt.figure(figsize=(10, 6))plt.plot(y_train, label='True Train')plt.plot(train_predict, label='Predicted Train')plt.plot(range(len(y_train), len(y_train) + len(y_test)), y_test, label='True Test')plt.plot(range(len(y_train), len(y_train) + len(y_test)), test_predict, label='Predicted Test')plt.legend(loc='upper left')plt.show()def plot_loss(history):plt.plot(history.history['loss'], label='Training Loss')plt.plot(history.history['val_loss'], label='Validation Loss')plt.title('Model Loss')plt.ylabel('Loss')plt.xlabel('Epoch')plt.legend(loc='upper right')plt.show()if __name__ == "__main__":# 生成数据data = generate_time_series()train_data, test_data = prepare_data(data)time_steps = 10X_train, y_train = create_dataset(train_data, time_steps)X_test, y_test = create_dataset(test_data, time_steps)# 构建模型input_shape = (X_train.shape[1], X_train.shape[2])model = build_model(input_shape)# 训练模型history = train_model(model, X_train, y_train, epochs=20)# 显示训练的loss,val_lossplot_loss(history)# 预测和可视化predict_and_visualize(model, X_train, y_train, X_test, y_test)

http://www.ppmy.cn/news/1563870.html

相关文章

基础入门-数据不回显数据不出网出入站策略正反向连接反弹Shell外带延迟写入

知识点: 1、数据不回显原因和解决-带外延迟反弹写文件 2、数据不出网原因和解决-出入站策略正反向连接 一、演示案例-数据不回显-原因解决-反弹&带外&延迟&写文件 原因:代码层面函数调用问题,没有输出测试等 实战过程&#xf…

科技赋能:多功能气膜综合馆引领场馆新革命—轻空间

在现代体育场馆建设中,如何为运动员提供更佳的比赛环境,为观众营造更舒适的观赛体验,已成为场馆设计的关键课题。而多功能气膜综合馆以其独特的声学优化技术和卓越的场馆功能,成功突破了传统气膜场馆的局限,为运动体验…

vue 纯前端导出 Excel

方法一: 1、安装"file-saver" npm i -S file-saver xlsx 2、引入 在需要导出功能的 .vue 文件中引入 import FileSaver from "file-saver"; import XLSX from "xlsx"; 3、简单示例(复制即可食用)&#x…

【C语言4】数组:一维数组、二维数组、变长数组及数组的练习题

文章目录 前言一、数组的概念二、一维数组2.1. 数组的创建和初始化2.2. 数组的类型2.3. 一维数组的下标2.4. 数组元素的打印和输入2.5. 一维数组在内存中的存储2.6. sizeof 计算数组元素个数 三、二维数组3.1. 二维数组的概念3.1. 二维数组的创建与初始化3.2. 二维数组的下标3.…

Graylog采集MySQL慢日志实战

文章目录 前言一、MySQL慢日志0. 慢查询相关语句1. 检查MySQL是否开启慢日志及慢查询保存位置2. 检查慢查询阈值3. 未使用索引是否开启记录慢查询日志4. 查看mysql.slow_log表结构及字段含义5. 慢查询记录两种情况示例 二、graylog采集慢查询日志1. 采集思路2. 创建Sidecar配置…

【Web系列三十】MYSQL库表比对升级脚本

写在前面 随着软件的迭代开发,数据库表有变动是常有的事,如果没有在开发时记录变更情况的话。对于线上生产环境下的MYSQL库表升级就会比较麻烦。 因此本文主要提供了一个脚本,方便比对新旧数据库的sql文件,从而自动生成用户升级的…

Cyberchef开发operation操作之-node开发环境搭建

本文介绍一下Cyberchef开发operation操作环境的搭建工作,为后续的Cyberchef开发operation操作提供开发环境基础,这里。该篇作为我的专栏《Cyberchef 从入门到精通教程》中的一篇,详见这里。 Linux环境 由于cyberchef只支持Linux和MAC的开发…

Python 替换excel 单元格内容

要在Python中替换Excel单元格的内容,你可以使用openpyxl库。openpyxl是一个用于读写Excel 2010 xlsx/xlsm/xltx/xltm文件的库。 安装openpyxl 首先,你需要安装openpyxl库。如果还没有安装,可以使用pip进行安装: pip install ope…