强化学习策略梯度算法实现文档(CartPole-v1)

server/2025/3/3 14:36:24/

1. 概述

本代码使用策略梯度方法(Policy Gradient)解决OpenAI Gym的CartPole-v1环境问题,包含以下核心组件:

  • 策略网络:神经网络输出动作概率分布

  • REINFORCE算法:带熵正则化的策略梯度方法

  • 训练监控:实时奖励跟踪与模型保存

  • 可视化:训练过程曲线与策略演示


2. 环境说明

python

复制

env = gym.make('CartPole-v1')
  • 状态空间:4维连续向量 [车位置, 车速, 杆角度, 杆角速度]

  • 动作空间:2个离散动作(左推/右推)

  • 奖励机制:每步存活奖励+1,最大步长500

  • 终止条件

    • 杆倾斜超过15度

    • 车移动超出±2.4单位

    • 连续存活500步(成功)


3. 策略网络架构

python

复制

class PolicyNetwork(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.fc = nn.Sequential(nn.Linear(input_dim, 128),nn.ReLU(),nn.Linear(128, 128),nn.ReLU(),nn.Linear(128, output_dim))
  • 输入层:4个神经元(与环境状态维度一致)

  • 隐藏层:2个全连接层(128神经元),ReLU激活

  • 输出层:2个神经元(对应动作数),输出logits

  • 设计特点

    • 无最终softmax层(由Categorical分布自动处理)

    • 深度结构增强表征能力


4. 训练算法实现
4.1 核心参数
参数作用
gamma0.99未来奖励折扣因子
entropy_coef0.01熵正则化系数
lr1e-3Adam优化器学习率
max_norm0.5梯度裁剪阈值
4.2 训练流程

python

复制

def train(...):# 数据收集阶段while not done:prob_dist = Categorical(logits=policy_net(state_tensor))action = prob_dist.sample()# 存储log_prob, entropy, reward等# 回报计算returns = (returns - returns.mean()) / (returns.std() + 1e-8)# 损失计算policy_loss = -log_prob * R - entropy_coef * entropy# 梯度更新total_loss.backward()torch.nn.utils.clip_grad_norm_(...)optimizer.step()
  1. 经验收集

    • 使用当前策略采样轨迹

    • 记录状态、动作、奖励、对数概率、熵值

  2. 回报计算

    • 折扣累计奖励:Rt=∑k=0Tγkrt+kRt​=∑k=0T​γkrt+k​

    • 标准化处理:R~t=(Rt−μR)/σRR~t​=(Rt​−μR​)/σR​

  3. 损失函数

    • 策略梯度损失:LPG=−E[log⁡π(a∣s)R~]LPG​=−E[logπ(a∣s)R~]

    • 熵正则项:Lent=−βH(π(⋅∣s))Lent​=−βH(π(⋅∣s))

    • 总损失:Ltotal=LPG+LentLtotal​=LPG​+Lent​

  4. 优化步骤

    • 梯度裁剪防止爆炸

    • Adam优化器更新参数


5. 关键技术点
5.1 熵正则化

python

复制

entropy = prob_dist.entropy()
policy_loss.append(... - entropy_coef * entropy)
  • 作用:增加探索,防止策略过早收敛

  • 效果:保持动作概率分布分散度

5.2 梯度裁剪

python

复制

torch.nn.utils.clip_grad_norm_(..., max_norm=0.5)
  • 原理:限制梯度L2范数不超过阈值

  • 优势:提升训练稳定性

5.3 状态标准化

python

复制

returns = (returns - returns.mean()) / (...)
  • 目的:减少回报方差

  • 注意:保留少量常数(1e-8)防止除零错误


6. 训练监控与评估
6.1 进度跟踪

python

复制

if (episode + 1) % 50 == 0:avg_reward = np.mean(episode_rewards[-50:])if avg_reward >= env.spec.reward_threshold:  # 默认阈值475print(f"Solved in {episode + 1} episodes!")
  • 输出频率:每50轮显示平均奖励

  • 停止条件:最近50轮平均奖励≥475

6.2 模型保存

python

复制

torch.save(policy_net.state_dict(), 'cartpole_policy.pth')
  • 格式:PyTorch模型参数

  • 用途:后续部署或继续训练

6.3 策略测试

python

复制

def test_policy(...):action = policy_net(...).argmax().item()env.render()
  • 策略选择:贪婪策略(取最大概率动作)

  • 渲染显示:可视化杆平衡过程


7. 可视化输出

python

复制

plt.plot(rewards)
plt.title('CartPole Training Progress')
  • X轴:训练轮次

  • Y轴:单轮总奖励

  • 典型曲线

    训练曲线示例


8. 运行与调优
8.1 执行命令

bash

复制

python cartpole_pg.py
8.2 预期输出

text

复制

Episode 50, Avg Reward (last 50): 42.3
Episode 100, Avg Reward (last 50): 195.2
Solved in 127 episodes!
8.3 调优建议
  • 学习率:尝试1e-4 ~ 3e-3范围

  • 网络结构:调整隐藏层维度(64-256)

  • 熵系数:0.001-0.1之间调节

  • 折扣因子:0.95-0.999


9. 扩展应用
  • 更换环境:适配MountainCar、LunarLander等离散动作环境

  • 算法改进

    • 添加基线(Baseline)减少方差

    • 实现PPO/TRPO等高级策略梯度方法

  • 分布式训练:使用多环境并行采样

此实现完整展示了策略梯度方法的核心思想,可作为强化学习基础实验平台。

完整代码

python">import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
import matplotlib.pyplot as plt  # 添加导入语句# 定义策略网络(增加层数和激活函数)
class PolicyNetwork(nn.Module):def __init__(self, input_dim, output_dim):super(PolicyNetwork, self).__init__()self.fc = nn.Sequential(nn.Linear(input_dim, 128),nn.ReLU(),nn.Linear(128, 128),nn.ReLU(),nn.Linear(128, output_dim))def forward(self, x):return self.fc(x)# 改进的训练函数(修复梯度计算,添加熵正则化)
def train(env, policy_net, optimizer, num_episodes=1500, gamma=0.99, entropy_coef=0.01):episode_rewards = []for episode in range(num_episodes):state, _ = env.reset()  # 适配新版gym APIstates, actions, rewards, log_probs, entropies = [], [], [], [], []done = False# 收集轨迹数据while not done:state_tensor = torch.FloatTensor(state)logits = policy_net(state_tensor)prob_dist = Categorical(logits=logits)action = prob_dist.sample()log_prob = prob_dist.log_prob(action)entropy = prob_dist.entropy()# 执行动作(适配新版gym API)next_state, reward, terminated, truncated, _ = env.step(action.item())done = terminated or truncated# 存储数据states.append(state_tensor)actions.append(action)rewards.append(reward)log_probs.append(log_prob)entropies.append(entropy)state = next_state# 计算折扣回报returns = []R = 0for r in reversed(rewards):R = r + gamma * Rreturns.insert(0, R)returns = torch.tensor(returns)# 标准化回报returns = (returns - returns.mean()) / (returns.std() + 1e-8)# 计算损失policy_loss = []for log_prob, R, entropy in zip(log_probs, returns, entropies):policy_loss.append(-log_prob * R - entropy_coef * entropy)total_loss = torch.stack(policy_loss).sum()# 反向传播optimizer.zero_grad()total_loss.backward()# 梯度裁剪(防止梯度爆炸)torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm=0.5)optimizer.step()# 记录训练进度total_reward = sum(rewards)episode_rewards.append(total_reward)# 显示训练进度if (episode + 1) % 50 == 0:avg_reward = np.mean(episode_rewards[-50:])print(f"Episode {episode + 1}, Avg Reward (last 50): {avg_reward:.1f}")if avg_reward >= env.spec.reward_threshold:print(f"Solved in {episode + 1} episodes!")breakreturn episode_rewards# 主函数(添加模型保存和测试功能)
if __name__ == "__main__":# 创建环境env = gym.make('CartPole-v1')state_dim = env.observation_space.shape[0]action_dim = env.action_space.n# 初始化网络和优化器policy_net = PolicyNetwork(state_dim, action_dim)optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)# 训练模型rewards = train(env, policy_net, optimizer, num_episodes=1000)# 保存模型torch.save(policy_net.state_dict(), 'cartpole_policy.pth')# 测试训练好的策略def test_policy(env, policy_net, episodes=10):for _ in range(episodes):state, _ = env.reset()done = Falsewhile not done:with torch.no_grad():action = policy_net(torch.FloatTensor(state)).argmax().item()state, reward, terminated, truncated, _ = env.step(action)done = terminated or truncatedenv.render()print(f"Test episode finished")env.close()test_policy(env, policy_net)# 绘制训练曲线plt.plot(rewards)plt.xlabel('Episode')plt.ylabel('Total Reward')plt.title('CartPole Training Progress')plt.show()


http://www.ppmy.cn/server/172093.html

相关文章

Spring Boot 与 MyBatis 数据库操作

一、核心原理 Spring Boot 的自动配置 通过 mybatis-spring-boot-starter 自动配置 DataSource(连接池)、SqlSessionFactory 和 SqlSessionTemplate。 扫描 Mapper 接口或指定包路径,生成动态代理实现类。 MyBatis 的核心组件 SqlSessionF…

【实战 ES】实战 Elasticsearch:快速上手与深度实践-2.1.1动态映射(Dynamic Mapping)的合理控制

👉 点击关注不迷路 👉 点击关注不迷路 👉 点击关注不迷路 文章大纲 Elasticsearch动态映射的合理控制与最佳实践1. 动态映射核心原理1.1 动态映射工作机制1.2 核心处理流程 2. 动态映射配置策略2.1 动态模式对照表2.2 配置示例 3. 字段类型自…

java容器 LIst、set、Map

Java容器中的List、Set、Map是核心数据结构,各自适用于不同的场景 一、List(有序、可重复) List接口代表有序集合,允许元素重复和通过索引访问,主要实现类包括: ArrayList 底层结构:动态数组…

数据集笔记:NUSMods API

1 介绍 NUSMods API 包含用于渲染 NUSMods 的数据。这些数据包括新加坡国立大学(NUS)提供的课程以及课程表的信息,还包括上课地点的详细信息。 可以使用并实验这些数据,它们是从教务处提供的官方 API 中提取的。 该 API 由静态的…

【AI实践】xiaozhi-esp32虾哥开源版-分析

语音交互总流程 客户端(ESP32) 服务器 | | | 本地唤醒词检测"小智" | | | | 打开音频通道 | |------------------------>| | | | 发送唤醒词音频 | |------------------------>| | | | 发送唤醒事件 | |------------------------>| | {"type":&qu…

基于SQL数据库的酒店管理系统

一、数据库设计 1.需求分析 客房的预定:可以通过网络进行预定,预定修改,取消预订。 客房管理:预定管理、客房查询、设置房态、开房、换房、续住、退房等管理。 员工管理: 员工修改信息、人员调配。 账务管理&…

windows安装vue

1、下载nodejs安装包 https://nodejs.cn/download/ 2、安装node 中途记得可以自己改安装路径,其他都是下一步 3、安装完成后检查 node -v :查看nodejs的版本 npm -v :查看npm的版本 4、修改npm默认安装目录与缓存日志目录的位置 在nodejs目…

LangPrompt提示词

LangPrompt提示词 https://github.com/langgptai/LangGPT 学习LangGPT的仓库,帮我创建 一个专门生成LangGPT格式prompt的助手 根据LangGPT的格式规范设计的专业提示词生成助手框架。以下是分步骤的解决方案: 助手角色定义模板 # Role: LangGPT提示词架…