【强化学习】PPO算法代码详解

news/2025/3/16 21:20:38/

介绍

PPO(Proximal Policy Optimization,近端策略优化)是一种用于强化学习的策略优化算法,由OpenAI在2017年提出。PPO结合了策略梯度方法的优点和信任区域优化(Trust Region Optimization)的思想,旨在实现高效、稳定的策略优化。它已成为强化学习中最常用的算法之一,广泛应用于各种任务,如游戏、机器人控制和自然语言处理等。

PPO的核心目标是通过限制策略更新的幅度,确保每次更新后的策略不会与之前的策略偏离太远,从而避免训练过程中的不稳定性和崩溃。具体来说,PPO通过引入一个“剪裁”(clipping)机制,限制策略更新的幅度,使其在一个安全的范围内进行。

PPO基于策略梯度方法,其目标函数可以表示为: 

L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) \cdot A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t \right) \right]

其中:r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}是新旧策略的概率比。 A_t是优势函数,表示当前动作相对于平均表现的优劣。  \epsilon 是一个超参数,用于控制剪裁的范围(通常取值为0.1到0.2)。 剪裁机制的作用是:当 r_t(\theta) 超出 [1-\epsilon, 1+\epsilon] 范围时,目标函数会被限制,从而避免过大的策略更新。

代码

1. 导入所需要的库

python">import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

2. 定义设备

python">print("============================================================================================")
# 设置设备为 cpu 或 cuda
device = torch.device('cpu')
if torch.cuda.is_available():device = torch.device('cuda:0')torch.cuda.empty_cache()print("设备设置为 : " + str(torch.cuda.get_device_name(device)))
else:print("设备设置为 : cpu")
print("============================================================================================")

3. 经验回放缓冲区

python"># 经验回放缓冲区
class RolloutBuffer:def __init__(self):self.actions = []         # 存储动作self.states = []          # 存储状态self.logprobs = []        # 存储对数概率self.rewards = []         # 存储奖励self.state_values = []    # 存储状态值self.is_terminals = []    # 存储是否终止标记def clear(self):# 清空所有缓存数据del self.actions[:]del self.states[:]del self.logprobs[:]del self.rewards[:]del self.state_values[:]del self.is_terminals[:]

4. Actor-Critic 网络

python"># Actor-Critic 网络
class ActorCritic(nn.Module):def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):super(ActorCritic, self).__init__()self.has_continuous_action_space = has_continuous_action_space# 如果是连续动作空间,则初始化动作方差if has_continuous_action_space:self.action_dim = action_dimself.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)# 定义 actor 网络if has_continuous_action_space:self.actor = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, action_dim),nn.Tanh())else:self.actor = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, action_dim),nn.Softmax(dim=-1))# 定义 critic 网络self.critic = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, 1))# 设置动作标准差def set_action_std(self, new_action_std):# 如果是连续动作空间: 更新 self.action_var ,计算新的动作方差# 如果是离散动作空间: 打印警告信息,提示该方法不适用于离散动作空间if self.has_continuous_action_space:self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)# 创建一个形状为 (action_dim,) 的张量,并用 action_std_init * action_std_init 填充所有元素else:print("--------------------------------------------------------------------------------------------")print("警告:在离散动作空间策略上调用 ActorCritic::set_action_std()")print("--------------------------------------------------------------------------------------------")# forward 方法未实现,直接抛出 NotImplementedError 异常# ActorCritic 类的主要功能通过 act 和 evaluate 方法实现,而不是 forwarddef forward(self):raise NotImplementedErrordef act(self, state):# 根据当前状态选择动作并返回动作、动作对数概率和状态值if self.has_continuous_action_space:action_mean = self.actor(state) # 通过 Actor 网络计算动作的均值cov_mat = torch.diag(self.action_var).unsqueeze(dim=0) # 构建协方差矩阵,使用 torch.diag 将对角矩阵扩展为合适的形状dist = MultivariateNormal(action_mean, cov_mat) # 用于生成动作else:action_probs = self.actor(state) # 通过 Actor 网络计算动作的概率分布dist = Categorical(action_probs) # 用于生成动作action = dist.sample() # 从分布中采样一个动作action_logprob = dist.log_prob(action) # 计算动作的对数概率state_val = self.critic(state) # 通过 Critic 网络评估状态值# 返回动作、动作对数概率和状态值,并调用detach()方法断开计算图return action.detach(), action_logprob.detach(), state_val.detach()def evaluate(self, state, action):# 评估给定状态和动作下的动作对数概率、状态值和分布熵if self.has_continuous_action_space:action_mean = self.actor(state)action_var = self.action_var.expand_as(action_mean)cov_mat = torch.diag_embed(action_var).to(device)dist = MultivariateNormal(action_mean, cov_mat)# 针对单一动作环境进行调整if self.action_dim == 1:action = action.reshape(-1, self.action_dim)else:action_probs = self.actor(state)dist = Categorical(action_probs)action_logprobs = dist.log_prob(action)dist_entropy = dist.entropy()state_values = self.critic(state)return action_logprobs, state_values, dist_entropy

​​​​​​​​​​​​​​为什么需要两个函数?

  • act 函数 :用于实际与环境交互,生成的动作需要与环境交互,因此不需要计算梯度。
  • evaluate 函数 :用于策略更新,需要计算梯度以优化网络参数。

5. PPO算法

python"># PPO 算法
class PPO:def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std_init=0.6):# 初始化参数self.has_continuous_action_space = has_continuous_action_spaceif has_continuous_action_space:self.action_std = action_std_initself.gamma = gammaself.eps_clip = eps_clipself.K_epochs = K_epochsself.buffer = RolloutBuffer()# 初始化当前策略网络和优化器self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)self.optimizer = torch.optim.Adam([{'params': self.policy.actor.parameters(), 'lr': lr_actor},{'params': self.policy.critic.parameters(), 'lr': lr_critic}])# 初始化旧策略网络,并复制当前策略的参数self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)self.policy_old.load_state_dict(self.policy.state_dict())# 初始化损失函数self.MseLoss = nn.MSELoss()# 设置动作标准差def set_action_std(self, new_action_std):if self.has_continuous_action_space:self.action_std = new_action_stdself.policy.set_action_std(new_action_std)self.policy_old.set_action_std(new_action_std)else:print("--------------------------------------------------------------------------------------------")print("警告:在离散动作空间策略上调用 PPO::set_action_std()")print("--------------------------------------------------------------------------------------------")# 衰减动作标准差def decay_action_std(self, action_std_decay_rate, min_action_std):print("--------------------------------------------------------------------------------------------")if self.has_continuous_action_space:self.action_std = self.action_std - action_std_decay_rateself.action_std = round(self.action_std, 4)if self.action_std <= min_action_std:self.action_std = min_action_stdprint("将 actor 输出的 action_std 设置为最小值 : ", self.action_std)else:print("将 actor 输出的 action_std 设置为 : ", self.action_std)self.set_action_std(self.action_std)else:print("警告:在离散动作空间策略上调用 PPO::decay_action_std()")print("--------------------------------------------------------------------------------------------")# 根据当前状态选择动作,并将数据存入缓冲区def select_action(self, state):if self.has_continuous_action_space:with torch.no_grad():state = torch.FloatTensor(state).to(device)action, action_logprob, state_val = self.policy_old.act(state)self.buffer.states.append(state)self.buffer.actions.append(action)self.buffer.logprobs.append(action_logprob)self.buffer.state_values.append(state_val)return action.detach().cpu().numpy().flatten()else:with torch.no_grad():state = torch.FloatTensor(state).to(device)action, action_logprob, state_val = self.policy_old.act(state)self.buffer.states.append(state)self.buffer.actions.append(action)self.buffer.logprobs.append(action_logprob)self.buffer.state_values.append(state_val)return action.item()# 更新策略def update(self):# 使用蒙特卡洛方法估计回报rewards = []discounted_reward = 0for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):if is_terminal:discounted_reward = 0discounted_reward = reward + (self.gamma * discounted_reward)rewards.insert(0, discounted_reward)# 对回报进行归一化处理rewards = torch.tensor(rewards, dtype=torch.float32).to(device)rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)# 将列表转换为张量old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)# 计算优势值advantages = rewards.detach() - old_state_values.detach()# 优化策略,进行 K 个 epoch 的训练for _ in range(self.K_epochs):# 评估旧策略下的动作和状态值logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)state_values = torch.squeeze(state_values)# 计算概率比率 (pi_theta / pi_theta_old)ratios = torch.exp(logprobs - old_logprobs.detach())# 计算代理损失surr1 = ratios * advantagessurr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages# PPO 剪切目标的最终损失loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy# 反向传播并更新梯度self.optimizer.zero_grad()loss.mean().backward()self.optimizer.step()# 将当前策略的参数复制给旧策略self.policy_old.load_state_dict(self.policy.state_dict())# 清空缓冲区self.buffer.clear()def save(self, checkpoint_path):# 保存模型参数到指定路径torch.save(self.policy_old.state_dict(), checkpoint_path)def load(self, checkpoint_path):# 从指定路径加载模型参数self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))


http://www.ppmy.cn/news/1579657.html

相关文章

微信小程序实现根据不同的用户角色显示不同的tabbar并且可以完整的切换tabbar

直接上图上代码吧 // login/login.js const app getApp() Page({/*** 页面的初始数据*/data: {},/*** 生命周期函数--监听页面加载*/onLoad(options) {},/*** 生命周期函数--监听页面初次渲染完成*/onReady() {},/*** 生命周期函数--监听页面显示*/onShow() {},/*** 生命周期函…

外呼系统破局电话管控:AI电销机器人合规运营实战指南

随着运营商对电话卡管控日趋严格&#xff0c;某金融科技公司曾因单日外呼超限导致80%号码被封——这一案例暴露出AI电销机器人在效率与合规间的矛盾。但数据显示&#xff0c;采用合规策略的企业外呼接通率仍能保持38%以上&#xff0c;关键在于建立适配监管环境的智能外呼体系。…

基于SSM + JSP 的水果蔬菜商城

基于ssm的水果蔬菜商城系统前台和后台&#xff08;源码安装视频数据库环境&#xff09;计算机项目程序设计管理系统java小程序网站商城 一.相关技术 Java、Spring、Springboot、MVC、Mybatis、MySQL、SSM框架、Web、HTML、maven、JavaScript、css、vue 二.部署配置 1.IntelliJ …

有效封装一个 WebSocket 供全局使用

前言 在现代 Web 应用中&#xff0c;实时通信已经成为越来越重要的一部分。而 WebSocket 技术的出现&#xff0c;使得实时通信变得更加高效和便捷。 WebSocket 协议是一种基于 TCP 协议的双向通信协议&#xff0c;它能够在客户端和服务器之间建立起持久性的连接&#xff0c;从…

【vue3学习笔记】(第144-146节)reactive函数;回顾vue2响应式原理;vue3响应式原理_proxy

尚硅谷Vue2.0Vue3.0全套教程丨vuejs从入门到精通 本篇内容对应课程第144-143节 课程 P144节 《reactive函数》笔记 验证 reactive 只能处理对象类型数据&#xff0c;不能处理基本类型数据&#xff1a;当使用reactive处理一个基本类型数据时&#xff0c;控制台直接报出了警告&a…

国家网络安全事件应急预案

目 录 1 总则 1.1 编制目的 1.2 编制依据 1.3 适用范围 1.4 事件分级 1.5 工作原则 2 组织机构与职责 2.1 领导机构与职责 2.2 办事机构与职责 2.3 各部门职责 2.4 各省&#xff08;区、市&#xff09;职责 3 监测与预警 3.1 预警分级 3.2 预警监测 3.3 预警研判…

附下载 | 2024 OWASP Top 10 基础设施安全风险.pdf

《2024 OWASP Top 10 基础设施安全风险》报告&#xff0c;由OWASP&#xff08;开放网络应用安全项目&#xff09;发布&#xff0c;旨在提升企业和组织对基础设施安全风险、威胁与漏洞的意识&#xff0c;并提供高质量的信息和最佳实践建议。报告列出了2024年最重要的10大基础设施…

Cesium零基础速成教程:一小时入门Cesium

一小时速成Cesium&#xff0c;掌握以下7个功能&#xff1a; 地图、图层、3D瓦片加载 Cesium空间数据Entity Cesium动态数据 地图事件(点击、移动) 相机 三维模型 粒子效果 Cesium教程配套笔记 部分内容 1、Cesium介绍 Cesium是使⽤JavaScript开发的基于WebGL的&#x…