代码
import gym
import numpy as np
import pygame  # noqa: F401  # not referenced directly; presumably needed for env.render()
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define the Actor network
class Actor(nn.Module):
    """Gaussian policy with tanh squashing for bounded continuous actions.

    Args:
        state_dim: Size of the flat observation vector.
        action_dim: Number of action dimensions.
        max_action: Sampled actions are scaled into ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mu = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, state):
        """Return ``(mu, std)`` of the Gaussian over the pre-squash action."""
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mu = self.mu(x)
        # Clamp log-std so std stays within [e^-20, e^2] (avoids 0 / huge std).
        log_std = torch.clamp(self.log_std(x), -20, 2)
        return mu, torch.exp(log_std)

    def sample(self, state):
        """Draw a reparameterised action and its log-probability.

        Returns:
            action: ``max_action * tanh(u)`` with ``u ~ Normal(mu, std)``;
                same shape as ``mu``.
            log_prob: Log-density of ``tanh(u)`` (before scaling by
                ``max_action``), summed over the last axis.
        """
        mu, std = self.forward(state)
        dist = torch.distributions.Normal(mu, std)
        u = dist.rsample()  # reparameterised sample, so gradients reach mu/std
        # Change of variables for a = tanh(u): evaluate the Gaussian at the
        # pre-squash sample u, then subtract log(1 - tanh(u)^2), written in the
        # numerically stable form 2 * (log 2 - u - softplus(-2u)).
        log_prob = dist.log_prob(u).sum(axis=-1)
        log_prob -= (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(axis=-1)
        action = torch.tanh(u) * self.max_action
        return action, log_prob


# Define the Critic network
class Critic(nn.Module):
    """State-action value network Q(s, a).

    State and action are concatenated along dim 1 and passed through two
    256-unit ReLU layers, then a linear head producing one value per row.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        joint_width = state_dim + action_dim
        self.fc1 = nn.Linear(joint_width, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, state, action):
        """Return Q-values of shape ``(batch, 1)`` for 2-D ``state``/``action``."""
        hidden = torch.cat((state, action), dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


# SAC algorithm
class SAC:
    """Soft Actor-Critic agent with twin Q critics, target critics, and
    automatic tuning of the entropy temperature alpha.

    Args:
        state_dim: Size of the flat observation vector.
        action_dim: Size of the action vector.
        max_action: Action bound passed through to ``Actor``.
        gamma: Discount factor used in the bootstrapped critic target.
        tau: Polyak coefficient for the soft target-critic update.
        target_entropy: Entropy target for the alpha loss. ``None`` (default)
            uses ``-action_dim``, the usual SAC heuristic.
    """

    def __init__(self, state_dim, action_dim, max_action,
                 gamma=0.99, tau=0.005, target_entropy=None):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.critic1 = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.target_critic1 = Critic(state_dim, action_dim)
        self.target_critic2 = Critic(state_dim, action_dim)
        # Targets start as exact copies of the online critics.
        self.target_critic1.load_state_dict(self.critic1.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=3e-4)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=3e-4)

        # alpha is optimised in log space so it stays positive; alpha starts at 0.1.
        self.log_alpha = torch.tensor(np.log(0.1), dtype=torch.float32, requires_grad=True)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=3e-4)
        self.target_entropy = (
            -float(action_dim) if target_entropy is None else float(target_entropy)
        )

        self.gamma = gamma
        self.tau = tau

    def select_action(self, state):
        """Sample a stochastic action for one observation.

        Returns a flat NumPy array of length ``action_dim``.
        """
        # Shape the observation as a batch of one: (1, state_dim).
        state = torch.FloatTensor(np.asarray(state).reshape(1, -1))
        with torch.no_grad():
            action, _ = self.actor.sample(state)
        return action.cpu().numpy().flatten()

    def update(self, replay_buffer, batch_size=256):
        """Run one SAC training step on a minibatch from ``replay_buffer``.

        ``replay_buffer.sample(batch_size)`` must return
        ``(state, action, next_state, reward, done)`` array-likes. The steps run
        in this order: critics, actor, alpha, then the soft target update.
        """
        state, action, next_state, reward, done = replay_buffer.sample(batch_size)
        state = torch.FloatTensor(state)
        action = torch.FloatTensor(action)
        next_state = torch.FloatTensor(next_state)
        # Column vectors so they line up with the (batch, 1) critic outputs.
        reward = torch.FloatTensor(reward).reshape(-1, 1)
        done = torch.FloatTensor(done).reshape(-1, 1)

        self._update_critics(state, action, reward, next_state, done)
        log_prob = self._update_actor(state)
        self._update_alpha(log_prob)
        self._soft_update(self.critic1, self.target_critic1, self.tau)
        self._soft_update(self.critic2, self.target_critic2, self.tau)

    def _update_critics(self, state, action, reward, next_state, done):
        """Take one gradient step on each critic toward the soft Bellman target."""
        with torch.no_grad():
            next_action, next_log_prob = self.actor.sample(next_state)
            target_q = torch.min(
                self.target_critic1(next_state, next_action),
                self.target_critic2(next_state, next_action),
            )
            # Actor.sample sums log-probs over the last axis, giving (batch,).
            # Reshape to (batch, 1); otherwise the subtraction broadcasts to
            # (batch, batch) and every Q regresses onto a batch-mean entropy term.
            target_q = target_q - self.log_alpha.exp() * next_log_prob.reshape(-1, 1)
            target_q = reward + (1 - done) * self.gamma * target_q

        for critic, optimizer in (
            (self.critic1, self.critic1_optimizer),
            (self.critic2, self.critic2_optimizer),
        ):
            loss = F.mse_loss(critic(state, action), target_q)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def _update_actor(self, state):
        """Take one gradient step on the actor; return its detached log-probs."""
        new_action, log_prob = self.actor.sample(state)
        q_new = torch.min(self.critic1(state, new_action), self.critic2(state, new_action))
        # alpha is a constant for the actor loss; it is trained by its own loss.
        alpha = self.log_alpha.exp().detach()
        actor_loss = (alpha * log_prob.reshape(-1, 1) - q_new).mean()
        self.actor_optimizer.zero_grad()
        # This backward also fills the critics' .grad buffers; the critic
        # optimizers call zero_grad before their next step, so that is harmless.
        actor_loss.backward()
        self.actor_optimizer.step()
        return log_prob.detach()

    def _update_alpha(self, log_prob):
        """Move log(alpha) so the policy entropy tracks ``self.target_entropy``."""
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

    @staticmethod
    def _soft_update(source, target, tau):
        """Polyak-average ``source`` parameters into ``target`` in place."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)


# Simple replay buffer
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions with uniform sampling.

    Transitions are stored as ``(state, action, next_state, reward, done)``
    tuples in ``self.buffer``. Once ``max_size`` entries exist, new ones
    overwrite the oldest slot, which is tracked by ``self.ptr``.
    """

    def __init__(self, max_size=1e6):
        self.buffer = []
        self.max_size = int(max_size)  # accepts float literals such as 1e6
        self.ptr = 0  # slot the next transition will occupy

    def add(self, state, action, next_state, reward, done):
        """Store one transition, overwriting the oldest entry when full."""
        transition = (state, action, next_state, reward, done)
        if len(self.buffer) >= self.max_size:
            self.buffer[self.ptr] = transition
        else:
            # While filling, ptr always equals len(buffer), so append lands at ptr.
            self.buffer.append(transition)
        self.ptr = (self.ptr + 1) % self.max_size

    def sample(self, batch_size):
        """Return ``batch_size`` transitions drawn uniformly with replacement.

        The result is a 5-tuple of NumPy arrays
        ``(states, actions, next_states, rewards, dones)``.
        """
        picks = np.random.randint(0, len(self.buffer), batch_size)
        batch = [self.buffer[i] for i in picks]
        return tuple(
            np.array([transition[field] for transition in batch]) for field in range(5)
        )


# Train the SAC agent
def _unpack_step(step_result):
    """Normalise ``env.step`` output across gym step APIs.

    Returns ``(next_state, reward, terminated, truncated)``. Newer gym returns
    a 5-tuple with separate ``terminated``/``truncated`` flags. Older gym
    returns a 4-tuple whose ``done`` also fires on time-limit truncation,
    which is reported in ``info["TimeLimit.truncated"]``.
    """
    if len(step_result) == 5:
        next_state, reward, terminated, truncated, _info = step_result
        return next_state, reward, terminated, truncated
    next_state, reward, done, info = step_result
    truncated = bool(info.get("TimeLimit.truncated", False))
    return next_state, reward, done and not truncated, truncated


env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

sac = SAC(state_dim, action_dim, max_action)
replay_buffer = ReplayBuffer()

max_episodes = 1000
batch_size = 256

for episode in range(max_episodes):
    state = env.reset()
    if isinstance(state, tuple):  # newer gym returns (observation, info)
        state = state[0]
    episode_reward = 0
    done = False
    while not done:
        # NOTE(review): on gym >= 0.26 this only draws if the env was made with
        # render_mode="human" — confirm for the installed gym version.
        env.render()
        action = sac.select_action(state)
        next_state, reward, terminated, truncated = _unpack_step(env.step(action))
        # Only genuine termination should stop bootstrapping. A time-limit
        # truncation (how Pendulum-v1 episodes end) is not a terminal state.
        replay_buffer.add(state, action, next_state, reward, terminated)
        state = next_state
        episode_reward += reward
        done = terminated or truncated
        if len(replay_buffer.buffer) > batch_size:
            sac.update(replay_buffer, batch_size)
    print(f"Episode {episode + 1}, Reward: {episode_reward}")
env.close()
简介
Soft Actor-Critic（SAC）是一种基于最大熵（Maximum Entropy）框架的离策略（off-policy）深度强化学习算法，主要面向连续动作空间。它将 Actor-Critic 框架与熵正则化（Entropy Regularization）相结合：在最大化累积奖励的同时最大化策略熵，从而在探索与利用之间取得良好的平衡。