介绍
PPO(Proximal Policy Optimization,近端策略优化)是一种用于强化学习的策略优化算法,由OpenAI在2017年提出。PPO结合了策略梯度方法的优点和信任区域优化(Trust Region Optimization)的思想,旨在实现高效、稳定的策略优化。它已成为强化学习中最常用的算法之一,广泛应用于各种任务,如游戏、机器人控制和自然语言处理等。
PPO的核心目标是通过限制策略更新的幅度,确保每次更新后的策略不会与之前的策略偏离太远,从而避免训练过程中的不稳定性和崩溃。具体来说,PPO通过引入一个“剪裁”(clipping)机制,限制策略更新的幅度,使其在一个安全的范围内进行。
PPO基于策略梯度方法,其目标函数可以表示为:
其中:是新旧策略的概率比。
是优势函数,表示当前动作相对于平均表现的优劣。
是一个超参数,用于控制剪裁的范围(通常取值为0.1到0.2)。 剪裁机制的作用是:当
超出
范围时,目标函数会被限制,从而避免过大的策略更新。
代码
1. 导入所需要的库
python">import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical
2. 定义设备
python">print("============================================================================================")
# 设置设备为 cpu 或 cuda
device = torch.device('cpu')
if torch.cuda.is_available():device = torch.device('cuda:0')torch.cuda.empty_cache()print("设备设置为 : " + str(torch.cuda.get_device_name(device)))
else:print("设备设置为 : cpu")
print("============================================================================================")
3. 经验回放缓冲区
python"># 经验回放缓冲区
class RolloutBuffer:def __init__(self):self.actions = [] # 存储动作self.states = [] # 存储状态self.logprobs = [] # 存储对数概率self.rewards = [] # 存储奖励self.state_values = [] # 存储状态值self.is_terminals = [] # 存储是否终止标记def clear(self):# 清空所有缓存数据del self.actions[:]del self.states[:]del self.logprobs[:]del self.rewards[:]del self.state_values[:]del self.is_terminals[:]
4. Actor-Critic 网络
python"># Actor-Critic 网络
class ActorCritic(nn.Module):def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):super(ActorCritic, self).__init__()self.has_continuous_action_space = has_continuous_action_space# 如果是连续动作空间,则初始化动作方差if has_continuous_action_space:self.action_dim = action_dimself.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)# 定义 actor 网络if has_continuous_action_space:self.actor = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, action_dim),nn.Tanh())else:self.actor = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, action_dim),nn.Softmax(dim=-1))# 定义 critic 网络self.critic = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, 1))# 设置动作标准差def set_action_std(self, new_action_std):# 如果是连续动作空间: 更新 self.action_var ,计算新的动作方差# 如果是离散动作空间: 打印警告信息,提示该方法不适用于离散动作空间if self.has_continuous_action_space:self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)# 创建一个形状为 (action_dim,) 的张量,并用 action_std_init * action_std_init 填充所有元素else:print("--------------------------------------------------------------------------------------------")print("警告:在离散动作空间策略上调用 ActorCritic::set_action_std()")print("--------------------------------------------------------------------------------------------")# forward 方法未实现,直接抛出 NotImplementedError 异常# ActorCritic 类的主要功能通过 act 和 evaluate 方法实现,而不是 forwarddef forward(self):raise NotImplementedErrordef act(self, state):# 根据当前状态选择动作并返回动作、动作对数概率和状态值if self.has_continuous_action_space:action_mean = self.actor(state) # 通过 Actor 网络计算动作的均值cov_mat = torch.diag(self.action_var).unsqueeze(dim=0) # 构建协方差矩阵,使用 torch.diag 将对角矩阵扩展为合适的形状dist = MultivariateNormal(action_mean, cov_mat) # 用于生成动作else:action_probs = self.actor(state) # 通过 Actor 网络计算动作的概率分布dist = Categorical(action_probs) # 用于生成动作action = dist.sample() # 从分布中采样一个动作action_logprob = dist.log_prob(action) # 计算动作的对数概率state_val = self.critic(state) # 通过 Critic 网络评估状态值# 返回动作、动作对数概率和状态值,并调用detach()方法断开计算图return action.detach(), action_logprob.detach(), state_val.detach()def evaluate(self, state, action):# 评估给定状态和动作下的动作对数概率、状态值和分布熵if self.has_continuous_action_space:action_mean = self.actor(state)action_var = self.action_var.expand_as(action_mean)cov_mat = torch.diag_embed(action_var).to(device)dist = MultivariateNormal(action_mean, cov_mat)# 针对单一动作环境进行调整if self.action_dim == 1:action = action.reshape(-1, self.action_dim)else:action_probs = self.actor(state)dist = Categorical(action_probs)action_logprobs = dist.log_prob(action)dist_entropy = dist.entropy()state_values = self.critic(state)return action_logprobs, state_values, dist_entropy
为什么需要两个函数?
- act 函数 :用于实际与环境交互,生成的动作需要与环境交互,因此不需要计算梯度。
- evaluate 函数 :用于策略更新,需要计算梯度以优化网络参数。
5. PPO算法
python"># PPO 算法
class PPO:def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std_init=0.6):# 初始化参数self.has_continuous_action_space = has_continuous_action_spaceif has_continuous_action_space:self.action_std = action_std_initself.gamma = gammaself.eps_clip = eps_clipself.K_epochs = K_epochsself.buffer = RolloutBuffer()# 初始化当前策略网络和优化器self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)self.optimizer = torch.optim.Adam([{'params': self.policy.actor.parameters(), 'lr': lr_actor},{'params': self.policy.critic.parameters(), 'lr': lr_critic}])# 初始化旧策略网络,并复制当前策略的参数self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)self.policy_old.load_state_dict(self.policy.state_dict())# 初始化损失函数self.MseLoss = nn.MSELoss()# 设置动作标准差def set_action_std(self, new_action_std):if self.has_continuous_action_space:self.action_std = new_action_stdself.policy.set_action_std(new_action_std)self.policy_old.set_action_std(new_action_std)else:print("--------------------------------------------------------------------------------------------")print("警告:在离散动作空间策略上调用 PPO::set_action_std()")print("--------------------------------------------------------------------------------------------")# 衰减动作标准差def decay_action_std(self, action_std_decay_rate, min_action_std):print("--------------------------------------------------------------------------------------------")if self.has_continuous_action_space:self.action_std = self.action_std - action_std_decay_rateself.action_std = round(self.action_std, 4)if self.action_std <= min_action_std:self.action_std = min_action_stdprint("将 actor 输出的 action_std 设置为最小值 : ", self.action_std)else:print("将 actor 输出的 action_std 设置为 : ", self.action_std)self.set_action_std(self.action_std)else:print("警告:在离散动作空间策略上调用 PPO::decay_action_std()")print("--------------------------------------------------------------------------------------------")# 根据当前状态选择动作,并将数据存入缓冲区def select_action(self, state):if self.has_continuous_action_space:with torch.no_grad():state = torch.FloatTensor(state).to(device)action, action_logprob, state_val = self.policy_old.act(state)self.buffer.states.append(state)self.buffer.actions.append(action)self.buffer.logprobs.append(action_logprob)self.buffer.state_values.append(state_val)return action.detach().cpu().numpy().flatten()else:with torch.no_grad():state = torch.FloatTensor(state).to(device)action, action_logprob, state_val = self.policy_old.act(state)self.buffer.states.append(state)self.buffer.actions.append(action)self.buffer.logprobs.append(action_logprob)self.buffer.state_values.append(state_val)return action.item()# 更新策略def update(self):# 使用蒙特卡洛方法估计回报rewards = []discounted_reward = 0for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):if is_terminal:discounted_reward = 0discounted_reward = reward + (self.gamma * discounted_reward)rewards.insert(0, discounted_reward)# 对回报进行归一化处理rewards = torch.tensor(rewards, dtype=torch.float32).to(device)rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)# 将列表转换为张量old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)# 计算优势值advantages = rewards.detach() - old_state_values.detach()# 优化策略,进行 K 个 epoch 的训练for _ in range(self.K_epochs):# 评估旧策略下的动作和状态值logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)state_values = torch.squeeze(state_values)# 计算概率比率 (pi_theta / pi_theta_old)ratios = torch.exp(logprobs - old_logprobs.detach())# 计算代理损失surr1 = ratios * advantagessurr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages# PPO 剪切目标的最终损失loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy# 反向传播并更新梯度self.optimizer.zero_grad()loss.mean().backward()self.optimizer.step()# 将当前策略的参数复制给旧策略self.policy_old.load_state_dict(self.policy.state_dict())# 清空缓冲区self.buffer.clear()def save(self, checkpoint_path):# 保存模型参数到指定路径torch.save(self.policy_old.state_dict(), checkpoint_path)def load(self, checkpoint_path):# 从指定路径加载模型参数self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))