RNN股票预测(Pytorch版)

ops/2024/9/24 8:53:44/

任务:基于zgpa_train.csv数据,建立RNN模型,预测股价
1.完成数据预处理,将序列数据转化为可用于RNN输入的数据
2.对新数据zgpa_test.csv进行预测,可视化结果
3.存储预测结果,并观察局部预测结果
备注:模型结构:单层RNN,输出有5个神经元,每次使用前8个数据预测第9个数据
参考视频:吹爆!3小时搞懂!【RNN循环神经网络+时间序列LSTM深度学习模型】学不会UP主下跪!
up主用的Keras,自己用Pytorch尝试了一下,代码如下:

import pandas as pd
import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt
data = pd.read_csv('zgpa_train.csv')
# loc 通过行索引 “Index” 中的具体值来取行数据
# 取出开盘价
price = data.loc[:,'close']# 归一化
price_norm = price/max(price)
# 开盘价折线图
# fig1 = plt.figure(figsize=(10, 6))
# plt.plot(price)
# plt.title('close price')
# plt.xlabel('time')
# plt.ylabel('price')
# plt.show()# 提取数据 每次使用前8个数据来预测第九个数据
def extract_data(data, time_step):x = []y = []for i in range(len(data)- time_step):x.append([a for a in data[i:i+time_step]])y.append(data[i + time_step])x = np.array(x)x = x.reshape(x.shape[0], x.shape[1], 1)x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32)return x, y
time_step = 8
x, y = extract_data(price_norm,time_step)
# print(x)
# print(y)
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers):super(RNN,self).__init__()self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first = True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.rnn(x)# print(out)out = self.fc(out[:, -1, :])out = out.squeeze(1)return out
# 定义模型参数
input_size = 1 # 输入特征的维度
hidden_size = 64 # 隐藏层的维度
output_size = 1 # 输出特征的维度
num_layers = 1 # RNN的层数# 创建模型
model = RNN(input_size, hidden_size, output_size, num_layers)# 定义损失函数和优化器
criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
epochs = 200
for epoch in range(epochs):optimizer.zero_grad()# outputs = model(x.unsqueeze(2))outputs = model(x)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')
# 进行预测 数据很少这里就不先保存模型再预测了
model.eval()
with torch.no_grad():y_train_predict = model(x) * max(price)
y_train = [i * max(price) for i in y]
# print(y_train_predict)
y_train_predict = y_train_predict.cpu().numpy()
y_train = np.array(y_train)
fig2 = plt.figure(figsize=(10, 6))
plt.plot(y_train_predict, label='Predicted', color='blue')
plt.plot(y_train, label='True', color='red', alpha=0.6)
plt.title('Predicted vs True Values')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()# 测试集
data_test = pd.read_csv('zgpa_test.csv')
price_test = data_test.loc[:,'close']
price_test_norm = price_test/max(price)
x_test,y_test = extract_data(price_test_norm,time_step)
with torch.no_grad():y_test_predict = model(x_test) * max(price)
y_test = [i * max(price) for i in y_test]
# print(y_train_predict)
y_test_predict = y_test_predict.cpu().numpy()
y_test = np.array(y_test)
fig3 = plt.figure(figsize=(10, 6))
plt.plot(y_test_predict, label='Predicted', color='blue')
plt.plot(y_test, label='True', color='red', alpha=0.6)
plt.title('Predicted vs True Values (Test Set)')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()# 存储数据
result_y_test = np.array(y_test).reshape(-1, 1) # 若干行,1列
result_y_test_predict = y_test_predict.reshape(-1, 1)
print(result_y_test.shape, result_y_test_predict.shape)
result = np.concatenate((result_y_test, result_y_test_predict), axis=1)
print(result.shape)
result = pd.DataFrame(result, columns=['real_price_test', 'predict_price_test'])
result.to_csv('zgpa_predict_test.csv')

http://www.ppmy.cn/ops/115211.html

相关文章

ant vue3 datePicker默认显示英文

改前: 改后: 处理方法: 在App.vue页加上以下导入即可 import dayjs from dayjs; import dayjs/locale/zh-cn dayjs.locale(zh-cn); 如图:

Python3 爬虫教程 - Web 网页基础

Web网页基础 1,网页的组成HTMLcssJavaScript2,网页的结构 3,节点树及节点间的关系4,选择器开头代表选择 id,其后紧跟 id 的名称。如:div 节点的 id 为 container,那么就可以表示为 #container 1…

Linux学习笔记13---GPIO 中断实验

中断系统是一个处理器重要的组成部分,中断系统极大的提高了 CPU 的执行效率,本章会将 I.MX6U 的一个 IO 作为输入中断,借此来讲解如何对 I.MX6U 的中断系统进行编程。 GIC 控制器简介 1、GIC 控制器总览 I.MX6U(Cortex-A)的中断控制器…

如何使用ssm实现基于Javaweb的网上花店系统的设计与实现

TOC ssm653基于Javaweb的网上花店系统的设计与实现jsp 研究背景 自计算机发展以来给人们的生活带来了改变。第一代计算机为1946年美国设计,最开始用于复杂的科学计算,占地面积、开机时间要求都非常高,经过数十几的改变计算机技术才发展到今…

zookeeper

目录 zookeeper概述 zookeeper工作机制 1. 数据模型 2. 会话管理 3. Watcher 机制 4. Leader 选举 5. 一致性协议(ZAB 协议) 6. 读写请求处理 zookeeper应用场景 统一命名服务 统一配置管理 统一集群管理 服务器动态上下线 软负载均衡 zoo…

vim 操作一列数字

一列数字从 9 到 23,想要将它们都减去 9 使用宏: a. 将光标移动到第一个数字 b. 按 qa 开始录制宏 c. 按 9 然后按 Ctrl-X (这会减去 9) d. 按 j 移动到下一行 e. 按 q 停止录制 f. 使用 a 重复宏,或 100a 重复多次 …

C++Thread封装

实现一个C的对pthread的封装基类,来实现对线程的启动,分离,等待结束以及取消操作,可以在派生类中定义run函数来实现线程的具体操作 定义头文件 // // Created by crab on 2024/9/24. //#ifndef THREAD_H #define THREAD_H#inclu…

基于飞腾平台的OpenCV的编译与安装

【写在前面】 飞腾开发者平台是基于飞腾自身强大的技术基础和开放能力,聚合行业内优秀资源而打造的。该平台覆盖了操作系统、算法、数据库、安全、平台工具、虚拟化、存储、网络、固件等多个前沿技术领域,包含了应用使能套件、软件仓库、软件支持、软件适…