lstm实践

devtools/2024/12/22 18:38:09/

今年华为杯研究生数学建模的C题第四问用到了lstm,这里配合代码简要地讲一下。

数据类型

磁通密度是一个时序数据,包含了一个周期内的磁通密度变化,我们需要对它进行降维,但PCA是不合适的,因为PCA主要关注数据的方差,无法有效捕捉周期性数据的重要特征,而磁通密度是周期性变化的。

在自然语言处理领域中,LSTM可以捕获序列内部元素之间的关联性,并且其隐藏层可以包含前序序列的信息。最后一层的隐藏层就包含了整个序列的信息,所以我们可以将最后一层的隐藏层作为降维后的向量。

我们选择LSTM对1024 维的磁通密度进行降维,具体做法是:训练时对一个周期进行切片,使用LSTM预测切片的下一时刻的磁通密度;降维时使用整个周期,获取最后一 层的hidden state作为该样本的磁通密度特征。

代码

1.数据处理

python">import pandas as pd
import torch as pt
import os
os.chdir('/home/burger/math/')df1 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料1')
df2 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料2')
df3 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料3')
df4 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料4')collom_name = [i for i in range(1,1024)]
B1 = df1[['0(磁通密度B,T)']+collom_name]
B2 = df2[['0(磁通密度,T)']+collom_name]
B3 = df3[['0(磁通密度B,T)']+collom_name]
B4 = df4[['0(磁通密度B,T)']+collom_name]
print(B1.head())B1_t = pt.tensor(B1.values)
B2_t = pt.tensor(B2.values)
B3_t = pt.tensor(B3.values)
B4_t = pt.tensor(B4.values)
B = pt.cat((B1_t, B2_t, B3_t, B4_t), 0)
print(B.shape)def create_dataset(data, time_step=64):  x, y = [], []  for i in range(0, data.shape[1] - time_step, 32):  a = data[:, i:(i + time_step)]  x.append(a)  y.append(data[:, i + time_step])  return pt.concat(x).float(), pt.concat(y).float()X, Y = create_dataset(B)
print(X.shape, Y.shape)
X, Y = X.unsqueeze(-1), Y.unsqueeze(-1)
print(X.shape, Y.shape)

 2.模型定义

python">import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm# LSTM模型定义
class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(LSTMModel, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):lstm_out, _ = self.lstm(x)out = self.fc(lstm_out[:, -1, :])  # 只取最后一个时间步的输出return outdef embedding(self, x):_, hid_cell = self.lstm(x)return hid_cell[0]

3.训练

python">input_size = 1
hidden_size = 1
output_size = 1
num_layers = 1
num_epochs = 5
batch_size = 2048
gpu = 6
train_dataset = TensorDataset(X.to(gpu), Y.to(gpu))
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)model = LSTMModel(input_size, hidden_size, output_size, num_layers).to(gpu)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)def train(model, dataloader, criterion, optimizer, num_epochs):model.train()for epoch in range(num_epochs):for inputs, labels in tqdm(dataloader, unit='batch'):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
train(model, train_loader, criterion, optimizer, num_epochs)

4.降维

python">df_san = pd.read_excel('./data/附件三(测试集).xlsx')
B_san = df_san[['0(磁通密度B,T)']+collom_name]
B_san_t = pt.tensor(B_san.values)
B_san_t = B_san_t.unsqueeze(-1).float().to(gpu)
print(B_san_t.shape)emb_dataset = TensorDataset(B_san_t)
emb_loader = DataLoader(dataset=emb_dataset, batch_size=400, shuffle=False)def embedding(model, dataloader):model.eval()embeddings = []for inputs in tqdm(dataloader, unit='batch'):outputs = model.embedding(inputs[0])embeddings.append(outputs)return pt.cat(embeddings).cpu().detach().numpy()embeddings = embedding(model, emb_loader)
print(embeddings.shape)
embeddings = embeddings.reshape(-1)
print(embeddings.shape)embeddings_df = pd.DataFrame({'磁通密度编码': embeddings,
})
embeddings_df.to_excel('./data/磁通密度编码1.xlsx', index=False)

有问题欢迎在评论区讨论!


http://www.ppmy.cn/devtools/120121.html

相关文章

神点SAAS云财务系统/多账套/前后端全开源

>>>系统简述: 神点SAAS云财务软件开源版,包含账套、凭证字、科目、期初、币别、账簿、报表、凭证、结账等功能。 神点云财务系统,餐饮行业财务软件、微服务架构财务软件、开源云财务软件、Java全开源财务软件优选! >…

学习鸿蒙Harmong基础(二)

1.类声明和使用 class Perpon { name : string "小赵"; age : number 24; isShow :boolean true; // 构造函数 constructor(name:string,age:number,isShow:boolean){ this.name name; this.age age; this.isShow isShow } puperyInfo(){ if (this.isShow) { …

Qt界面优化——绘图API

文章目录 绘图核心API绘制各种形状绘制线段绘制矩形绘制圆形绘制文本设置画笔设置画刷 绘制图片 绘图核心API Qt的各种控件,本质上都是画出来的,这不过这些都是提前画好了,我们拿过来直接使用即可。 实际开发中,可能现有控件无法…

C++入门(有C语言基础)

string类 string类初始化的方式大概有以下几种: string str1;string str2 "hello str2";string str3("hello str3");string str4(5, B);string str5[3] {"Xiaomi", "BYD", "XPeng"};string str6 str5[2];str…

基于J2EE技术的高校社团综合服务系统

目录 毕设制作流程功能和技术介绍系统实现截图开发核心技术介绍:使用说明开发步骤编译运行代码执行流程核心代码部分展示可行性分析软件测试详细视频演示源码获取 毕设制作流程 (1)与指导老师确定系统主要功能; (2&am…

【Python】Uvicorn:Python 异步 ASGI 服务器详解

Uvicorn 是一个为 Python 设计的 ASGI(异步服务器网关接口)Web 服务器。它填补了 Python 在异步框架中缺乏一个最小化低层次服务器/应用接口的空白。Uvicorn 支持 HTTP/1.1 和 WebSockets,是构建现代异步Web应用的强大工具。 ⭕️宇宙起点 &a…

深入理解同步和异步与reactor和proactor模式

在现代网络编程中,I/O 设计模式对于提高性能和资源利用率至关重要。本文将探讨两种主要的网络 I/O 设计模式:同步 I/O 和异步 I/O,以及它们的实现方式。 同步 I/O 同步 I/O 模式要求用户通过系统调用函数,如 read(), write(), c…

javacv FFmpegFrameGrabber 阻塞重连解决方法汇总

JavaCV中FrameGrabber类可以连接直播流地址, 进行解码, 获取Frame帧信息, 常用方式如下 FrameGrabber grabber new FrameGrabber("rtsp:/192.168.0.0"); while(true) {Frame frame grabber.grabImage();// ... } 在如上代码中, 若连接地址网络不通, 或者连接超时…