RLHF 的启示:微调 LSTM 能更好预测股票?

server/2024/10/20 10:14:50/

作者:老余捞鱼

原创不易,转载请标明出处及原作者。

写在前面的话:
       
在财务预测领域,准确预测股票价格是一项具有挑战性但至关重要的任务。传统方法通常难以应对股票市场固有的波动性和复杂性。这篇文章介绍了一种创新方法,该方法将长短期记忆 (LSTM) 网络与基于评分的微调机制相结合,以增强股票价格预测。我们将以 Reliance Industries Limited 的股票作为我们的案例研究,展示这种方法如何潜在地提高预测准确性。

一、核心理念

       受 RLHF 的启发,我们尝试在时间序列预测中应用相同的概念,RLHF的概念因为ChatGPT的出现,可能第一次出现在大多数人的眼里,RLHF 是 "Reinforcement Learning from Human Feedback" 的缩写,这是一种结合了强化学习和人类反馈的机器学习方法。在这种方法中,人工智能(AI)系统通过执行任务并接收人类评估者对其行为的反馈来学习。这种方法特别适用于那些难以用传统奖励函数明确定义任务成功与否的情况。回到正题,我们的方法围绕三个关键组成部分:

1. 用于初始股票价格预测的LSTM模型
2.评估这些预测质量的评分模型
3.使用评分模型的输出来优化 LSTM 性能的微调过程

       通过集成这些组件,我们的目标是创建一个更具适应性和准确性的预测系统,从而更好地捕捉股价变动的细微差别。

二、架构概述

1. LSTM 模型:
       我们系统的核心是 LSTM 神经网络。LSTM 特别适合于股票价格等时间序列数据,因为它们能够捕获数据中的长期依赖关系。我们的 LSTM 模型将一系列历史股票价格作为输入,并预测序列中的下一个价格。

2. 评分模型:
       评分模型是一个单独的神经网络,旨在评估 LSTM 预测的质量。它采用原始价格序列和 LSTM 的预测作为输入,输出一个表示 LSTM 预测预测准确性的分数。

3. 微调机制:
       该组件使用评分模型生成的分数来调整 LSTM 的训练过程。在微调过程中,从评分模型获得较高分数的预测会得到更大的权重,从而鼓励 LSTM 学习模式,从而获得更准确的预测。

三、工作流程

1. 数据准备:
       我们首先使用 yfinance 库获取 Reliance Industries Limited 的历史股票价格数据。然后,这些数据被预处理并拆分为适合 LSTM 训练的序列。

2. 初始 LSTM 训练:
       LSTM 模型在部分历史数据上进行训练。这为我们提供了一个能够做出合理股票价格预测的基准模型。

3. 评分模型训练:
       我们使用另一部分数据来训练评分模型。该模型通过将 LSTM 的预测与实际股票价格进行比较来学习评估 LSTM 预测的质量。

4. 微调过程:
       使用数据的第三部分,我们对 LSTM 模型进行微调。在此过程中,我们使用评分模型来评估每个预测。LSTM 的学习率会根据这些分数进行调整,使其能够更专注于改进评分模型认为不太准确的预测。

5. 评估:
       最后,我们在测试集上评估原始 LSTM 和微调后的 LSTM 的性能,比较它们的预测以评估微调方法的有效性。

四、代码实现

       让我们将代码分解为多个部分并详细解释每个部分。

1. 导入库并设置环境

import yfinance as yf
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import torch.nn
import torch.nn import
torch.optim
as optim from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”)
print(f“Using device:{device}”)

       此部分导入所有必要的库。我们使用 yfinance 来获取股票数据,使用 numpy 和 pandas 进行数据操作,使用 sklearn 进行数据预处理,使用 torch 构建和训练神经网络,使用 matplotlib 进行可视化。我们还设置了 PyTorch 将用于计算的设备(CPU 或 GPU)。

2. 数据获取和预处理

reliance = yf.Ticker(“RELIANCE.NS”)
data = reliance.history(period=”max”)[‘Close’].values.reshape(-1, 1)scaler = MinMaxScaler(feature_range=(0, 1))
data_normalized = scaler.fit_transform(data)def create_sequences(data, seq_length):
sequences = []
targets = []
for i in range(len(data) — seq_length):
seq = data[i:i+seq_length]
target = data[i+seq_length]
sequences.append(seq)
targets.append(target)
return np.array(sequences), np.array(targets)seq_length = 60 # 60 days of historical data
X, y = create_sequences(data_normalized, seq_length)

       在这里,我们获取 Reliance Industries Limited 股票的历史收盘价。我们使用 MinMaxScaler 对数据进行归一化,以确保所有值都在 0 到 1 之间,这有助于训练神经网络。

       “create_sequences”功能至关重要。它将我们的时间序列数据转换为适合 LSTM 训练的格式。对于每个数据点,它会创建一个前 60 天 (seq_length) 的序列作为输入,并以第二天的价格为目标。

3. 数据切分

lstm_split = int(0.5 * len(X))
scoring_split = int(0.75 * len(X))X_lstm, y_lstm = X[:lstm_split], y[:lstm_split]
X_scoring, y_scoring = X[lstm_split:scoring_split], y[lstm_split:scoring_split]
X_finetuning, y_finetuning = X[scoring_split:], y[scoring_split:]lstm_train_split = int(0.8 * len(X_lstm))
X_lstm_train, y_lstm_train = X_lstm[:lstm_train_split], y_lstm[:lstm_train_split]
X_lstm_test, y_lstm_test = X_lstm[lstm_train_split:], y_lstm[lstm_train_split:]

       我们将数据分为三个主要部分:

1. LSTM 训练和测试
2.评分模型训练
3.微调

       这确保了我们流程的每个阶段都使用单独的数据,防止数据泄露,并对我们的方法进行公平评估。

4. LSTM 模型定义

class LSTMModel(nn.Module):def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_s

http://www.ppmy.cn/server/127363.html

相关文章

忘记 MySQL 密码怎么办:破解 root 账户密码

忘记 MySQL 密码怎么办:破解 root 账户密码 目录 忘记 MySQL 密码怎么办:破解 root 账户密码1、修改 MySQL 配置文件2、不使用密码登录 MySQL3、重置 root 用户密码4、修改 MySQL 配置文件并重启 MySQL 服务5、使用新密码登录 MySQL 如果忘记密码导致无法…

【视频目标分割-2024CVPR】Putting the Object Back into Video Object Segmentation

Cutie 系列文章目录1 摘要2 引言2.1背景和难点2.2 解决方案2.3 成果 3 相关方法3.1 基于记忆的VOS3.2对象级推理3.3 自动视频分割 4 工作方法4.1 overview4.2 对象变换器4.2.1 overview4.2.2 Foreground-Background Masked Attention4.2.3 Positional Embeddings 4.3 Object Me…

C++语言学习(6):《C++程序设计原理与实践》第一章笔记

最近在看 C之父 BS 的 《C程序设计原理与实践》, 记录下。 目标读者 本书适合于哪些从未有过编程经验但愿意努力学习程序设计的初学者,它能帮助你理解使用C语言进行程序设计的基本原理并获得实践技巧。 作为大学课程大概需要15小时/周 * 14周 210 小时。 本书不是…

Redis设计与实现 学习笔记 第五章 跳跃表

跳跃表(skiplist)是一种有序的数据结构,它通过在每个节点中维持多个指向其他节点的指针,达到快速访问节点的目的。 跳跃表支持平均O(logN)、最坏O(N)复杂度的节点查找,还可以通过顺序性操作来批量处理节点。 在大部分…

卡码网KamaCoder 53. 寻宝

题目来源:53. 寻宝(第七期模拟笔试) C题解(来源代码随想录):最小生成树 prim prim三部曲 第一步,选距离生成树最近节点第二步,最近节点加入生成树第三步,更新非生成树节…

[Linux]:线程(二)

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:Linux学习 贝蒂的主页:Betty’s blog 与Windows环境不同,我们在linux环境下需要通过指令进行各操作&…

Socket套接字(客户端,服务端)和IO多路复用

Socket套接字(客户端,服务端) 目录 socket是什么一、在客户端1. 创建套接字2. 设置服务器地址3. 连接到服务器4. 发送数据5. 接收数据6. 关闭连接 二、内核态与用户态切换三、系统调用与上下文切换的关系四、在服务端1. 创建 Socket (用户态…

两名大学生利用Meta的智能眼镜展示了一项令人震惊的技术,能够实时“人肉”他人的身份信息

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…