pytorch实现长短期记忆网络 (LSTM)

server/2025/2/5 20:00:27/

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

LSTM 通过 记忆单元(cell)三个门控机制(遗忘门、输入门、输出门)来控制信息流:

 记忆单元(Cell State)

  • 负责存储长期信息,并通过门控机制决定保留或丢弃信息。

 遗忘门(Forget Gate, ftf_tft​)

 输入门(Input Gate, iti_tit​)

 输出门(Output Gate, oto_tot​)

特性

传统 RNNLSTM
记忆能力短期记忆长短期记忆
计算复杂度
解决梯度消失
适用场景短序列数据长序列数据

LSTM 应用场景

  • 自然语言处理(NLP):文本生成、情感分析、机器翻译
  • 时间序列预测:股票预测、天气预报、传感器数据分析
  • 语音识别:自动字幕生成、语音转文字(ASR)
  • 机器人与控制系统:智能体决策、自动驾驶

例子:

下面例子实现了一个 基于 LSTM 的强化学习智能体,在 1D 网格环境 里移动,并找到最优路径。
最终,我们 绘制 5 条测试路径,并高亮显示最佳路径(红色)

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# ========== 1. 定义 LSTM 策略网络 ==========
class LSTMPolicy(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(LSTMPolicy, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.Softmax(dim=-1)def forward(self, x, hidden_state):batch_size = x.size(0)# 确保 hidden_state 维度正确if hidden_state[0].dim() == 2:hidden_state = (hidden_state[0].unsqueeze(1).repeat(1, batch_size, 1),hidden_state[1].unsqueeze(1).repeat(1, batch_size, 1))out, hidden_state = self.lstm(x, hidden_state)out = self.fc(out[:, -1, :])  # 取最后时间步的输出action_prob = self.softmax(out)  # 归一化输出,作为策略return action_prob, hidden_statedef init_hidden(self, batch_size=1):return (torch.zeros(self.num_layers, batch_size, self.hidden_size),torch.zeros(self.num_layers, batch_size, self.hidden_size))# ========== 2. 创建网格环境 ==========
class GridWorld:def __init__(self, grid_size=10, goal_position=9):self.grid_size = grid_sizeself.goal_position = goal_positionself.reset()def reset(self):self.position = 0return self.positiondef step(self, action):if action == 0:self.position = max(0, self.position - 1)elif action == 1:self.position = min(self.grid_size - 1, self.position + 1)reward = 1 if self.position == self.goal_position else -0.1done = self.position == self.goal_positionreturn self.position, reward, done# ========== 3. 训练智能体 ==========
def train(num_episodes=500, max_steps=50):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)optimizer = optim.Adam(policy.parameters(), lr=0.01)gamma = 0.99for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)log_probs = []rewards = []for step in range(max_steps):action_probs, hidden_state = policy(state, hidden_state)action = torch.multinomial(action_probs, 1).item()log_prob = torch.log(action_probs.squeeze(0)[action])log_probs.append(log_prob)next_state, reward, done = env.step(action)rewards.append(reward)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 计算回报并更新策略returns = []R = 0for r in reversed(rewards):R = r + gamma * Rreturns.insert(0, R)returns = torch.tensor(returns, dtype=torch.float32)returns = (returns - returns.mean()) / (returns.std() + 1e-9)loss = sum([-log_prob * R for log_prob, R in zip(log_probs, returns)])optimizer.zero_grad()loss.backward()optimizer.step()if (episode + 1) % 50 == 0:print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {sum(rewards)}")torch.save(policy.state_dict(), "policy.pth")# 训练智能体
train(500)# ========== 4. 测试智能体并绘制最佳路径 ==========
def test(num_episodes=5):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)policy.load_state_dict(torch.load("policy.pth"))plt.figure(figsize=(10, 5))best_path = Nonebest_steps = float('inf')for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)positions = [env.position]  # 记录位置变化while True:action_probs, hidden_state = policy(state, hidden_state)action = torch.argmax(action_probs, dim=-1).item()next_state, reward, done = env.step(action)positions.append(next_state)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 记录最佳路径(最短步数)if len(positions) < best_steps:best_steps = len(positions)best_path = positions# 绘制普通路径(蓝色)plt.plot(range(len(positions)), positions, marker='o', linestyle='-', color='blue', alpha=0.6,label=f'Episode {episode + 1}' if episode == 0 else "")# 绘制最佳路径(红色)if best_path:plt.plot(range(len(best_path)), best_path, marker='o', linestyle='-', color='red', linewidth=2,label="Best Path")# 打印最佳路径print(f"Best Path (steps={best_steps}): {best_path}")plt.xlabel("Time Steps")plt.ylabel("Agent Position")plt.title("Agent's Movement Path (Best Path in Red)")plt.legend()plt.grid(True)plt.show()# 测试并绘制智能体移动路径
test(5)


http://www.ppmy.cn/server/165219.html

相关文章

Vue.js 使用组件库构建 UI

Vue.js 使用组件库构建 UI 在 Vue.js 项目中&#xff0c;构建漂亮又高效的用户界面&#xff08;UI&#xff09;是很重要的一环。组件库就是你开发 UI 的好帮手&#xff0c;它可以大大提高开发效率&#xff0c;减少重复工作&#xff0c;还能让你的项目更具一致性和专业感。今天…

面试--你的数据库中密码是如何存储的?

文章目录 三种分类使用 MD5 加密存储加盐存储Base64 编码:常见的对称加密算法常见的非对称加密算法https 传输加密 在开发中需要存储用户的密码&#xff0c;这个密码一定是加密存储的&#xff0c;如果是明文存储那么如果数据库被攻击了&#xff0c;密码就泄露了。 我们要对数据…

19C RAC在vmware虚拟机环境下的安装

RAC安装规划 ===IP== ORA19C01 public ip : 192.168.229.191 heatbeat : 192.168.0.1 vip : 192.168.229.193 ORA19C02 public ip :192.168.229.192 heatbeat : 192.168.0.2 vip : 192.168.229.194 scan ip 192.168.229.195 hosts: echo "192.168.229…

自学习记录-编程语言的特点(持续记录)

我学习的顺序是C -> python -> C -> Java。在讲到某项语言的特点是&#xff0c;可能会时不时穿插其他语言的特点。 Java 1 注解Annotation Python中也有类似的Decorators。以下为AI学习了解到的&#xff1a; Java的Annotation是一种元数据&#xff08;metadata)&a…

【华为OD机试】真题E卷-招聘(Java)

一、题目描述 题目描述: 某公司组织一场公开招聘活动,假设由于人数和场地的限制,每人每次面试的时长不等,并已经安排给定,用(S1,E1)、 (S2,E2)、 (Sj,Ej)…(Si < Ei,均为非负整数)表示每场面试的开始和结束时间。 面试采用一对一的方式,即一名面试官同时只能面试一名…

ROS应用之SwarmSim在ROS 中的协同路径规划

SwarmSim 在 ROS 中的协同路径规划 前言 在多机器人系统&#xff08;Multi-Robot Systems, MRS&#xff09;中&#xff0c;SwarmSim 是一个常用的模拟工具&#xff0c;可以对多机器人进行仿真以实现复杂任务的协同。除了任务分配逻辑以外&#xff0c;SwarmSim 在协同路径规划方…

攻防世界 simple_php

&#xfeff;<?php show_source(__FILE__);//显示 PHP 文件的源码 include("config.php");// 包含了一个config.php文件 $a$_GET[a];//获取GET 参数 a 和 b $b$_GET[b]; if($a0 and $a) {echo $flag1; } if(is_numeric($b))//防止数字输入 {exit(); } if($b>1…

【ChatGPT:开启人工智能新纪元】

一、ChatGPT 是什么 最近,ChatGPT 可是火得一塌糊涂,不管是在科技圈、媒体界,还是咱们普通人的日常聊天里,都能听到它的大名。好多人都在讨论,这 ChatGPT 到底是个啥 “神器”,能让大家这么着迷?今天咱就好好唠唠。 ChatGPT,全称是 Chat Generative Pre-trained Trans…