pytorch实现循环神经网络

server/2025/2/4 8:17:30/

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

PyTorch 提供三种主要的 RNN 变体:

  • nn.RNN:最基本的循环神经网络,适用于短时依赖任务。
  • nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决梯度消失问题。
  • nn.GRU:门控循环单元,比 LSTM 计算更高效,适用于大部分任务。
网络类型优势适用场景
RNN计算简单,适用于短时序列语音、文本处理(短序列)
LSTM适用于长序列,能记忆长期信息机器翻译、语音识别、股票预测
GRU比 LSTM 计算更高效,效果相似语音处理、文本生成

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 生成正弦波数据(仅使用 PyTorch)
def generate_sine_wave(seq_length=10, num_samples=1000):x = torch.linspace(0, 100, num_samples)  # 生成 1000 个等间距数据点y = torch.sin(x)  # 计算正弦值X_data, Y_data = [], []for i in range(len(y) - seq_length):X_data.append(y[i:i + seq_length].unsqueeze(-1))  # 过去 seq_length 作为输入Y_data.append(y[i + seq_length])  # 预测下一个点return torch.stack(X_data), torch.tensor(Y_data).unsqueeze(-1)# 生成数据
seq_length = 10  # 序列长度
X, Y = generate_sine_wave(seq_length)# 划分训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
Y_train, Y_test = Y[:train_size], Y[train_size:]# 2. 定义 RNN 模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  # 初始化隐藏状态out, _ = self.rnn(x, h0)out = self.fc(out[:, -1, :])  # 取最后一个时间步的输出return out# 3. 训练模型
# 超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 1
num_epochs = 100
learning_rate = 0.001# 初始化模型
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, Y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估与绘图
model.eval()
with torch.no_grad():predictions = model(X_test)# 画图
plt.figure(figsize=(10, 5))
plt.plot(Y_test.numpy(), label="Real Data")
plt.plot(predictions.numpy(), label="Predicted Data")
plt.legend()
plt.title("RNN Sine Wave Prediction")
plt.show()

代码解析

数据生成

  • torch.linspace(0, 100, num_samples) 生成 1000 个均匀分布的数据点。
  • torch.sin(x) 计算正弦值,形成时间序列数据。
  • X过去 10 个时间步的数据,Y下一个时间步的预测目标

构建 RNN

  • nn.RNN(input_size, hidden_size, num_layers, batch_first=True) 定义循环神经网络
    • input_size=1:每个时间步只有一个输入值(正弦波)。
    • hidden_size=32:隐藏层神经元数目。
    • num_layers=1:单层 RNN。
  • self.fc = nn.Linear(hidden_size, output_size) 负责最终输出。

训练

  • 使用 MSELoss(均方误差损失) 计算预测值与真实值的误差。
  • 使用 Adam 优化器 更新模型参数。
  • 每 10 个 epoch 输出一次损失 loss

测试 & 绘图

  • 关闭梯度计算 (torch.no_grad()),执行前向传播预测测试数据。
  • Matplotlib 绘制预测曲线与真实曲线。

运行效果

如果训练成功,预测曲线(橙色)应该与真实曲线(蓝色)非常接近


http://www.ppmy.cn/server/164837.html

相关文章

【Redis_2】短信登录

一、基于Session实现登录 RegexUtils:是定义的关于一些格式的正则表达式的工具箱 package com.hmdp.utils;import cn.hutool.core.util.StrUtil;public class RegexUtils {/*** 是否是无效手机格式* param phone 要校验的手机号* return true:符合,false&#xff…

对顾客行为的数据分析:融入2+1链动模式、AI智能名片与S2B2C商城小程序的新视角

摘要:随着互联网技术的飞速发展,企业与顾客之间的交互方式变得日益多样化,移动设备、社交媒体、门店、电子商务网站等交互点应运而生。这些交互点不仅为顾客提供了便捷的服务体验,同时也为企业积累了大量的顾客行为数据。本文旨在…

【含文档+PPT+源码】基于大数据的交通流量预测系统

项目介绍 本课程演示的是一款基于Python的图书管理系统的设计与实现,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 包含:项目源码、项目文档、数据库脚本、软件工具等所有资料 带你从零开始部署运行本套系统 该项目附…

git基础使用--4---git分支和使用

文章目录 git基础使用--4---git分支和使用1. 按顺序看2. 什么是分支3. 分支的基本操作4. 分支的基本操作4.1 查看分支4.2 创建分支4.3 切换分支4.4 合并冲突 git基础使用–4—git分支和使用 1. 按顺序看 -git基础使用–1–版本控制的基本概念 -git基础使用–2–gti的基本概念…

python Flask-Redis 连接远程redis

当使用Flask-Redis连接远程Redis时,首先需要安装Flask-Redis库。可以通过以下命令进行安装: pip install Flask-Redis然后,你可以使用以下示例代码连接远程Redis: from flask import Flask from flask_redis import FlaskRedisa…

基于物联网技术的实时数据流可视化研究(论文+源码)

1系统方案设计 根据系统功能的设计要求,展开基于物联网技术的实时数据流可视化研究设计。如图2.1所示为系统总体设计框图,系统以STM32单片机做为主控制器,通过DHT11、MQ-2、光照传感器实现环境中温湿度、烟雾、光照强度数据的实时检测&#x…

Vue3 结合 .NetCore WebApi 前后端分离跨域请求简易实例

1、本地安装Vue3环境 参考:VUE3中文文档-快速上手 注意:初始安装vue时,需要安装router,否则后续也要安装 2、安装axios组件 比如:npm install axioslatest 或 pnpm install axioslatest 3、设置跨域请求代理 打开…

Java程序设计:掌握核心语法与经典案例

一、引言 Java作为一种广泛使用的编程语言,以其简洁、高效、跨平台等特性深受开发者喜爱。掌握Java的核心语法是成为一名优秀Java程序员的基础。本文将从Java的基本语法入手,逐步深入到经典案例的分析,帮助读者快速掌握Java程序设计的关键要…