基于“动手学强化学习”的知识点(一):第 14 章 SAC 算法(gym版本 >= 0.26)

server/2025/3/15 23:10:29/

第 14 章 SAC 算法(gym版本 >= 0.26)

  • 摘要
  • SAC 算法(连续)
  • SAC 算法(离散)

摘要

本系列知识点讲解基于动手学强化学习中的内容进行详细的疑难点分析!具体内容请阅读动手学强化学习


对应动手学强化学习——SAC 算法


SAC 算法(连续)

# -*- coding: utf-8 -*-import random
import gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import rl_utilsclass PolicyNetContinuous(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNetContinuous, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)self.fc_std = torch.nn.Linear(hidden_dim, action_dim)'''作用:保存动作幅度的界限,便于后续对动作做缩放。数值例子:若 action_bound=2,最终动作将会在 [-2, 2] 范围内。'''self.action_bound = action_bounddef forward(self, x):x = F.relu(self.fc1(x))mu = self.fc_mu(x)std = F.softplus(self.fc_std(x))'''作用:使用上面计算得到的 mu 和 std 构造正态分布对象 dist。数值例子:- 这时构造的分布为 𝑁(0.8,0.474^2)。'''dist = Normal(mu, std)'''作用:从正态分布中采样,但采用“重参数化采样”(rsample),以便后续能对采样过程进行梯度反传。数值例子:- 例如,若采样时随机变量 ε 从标准正态分布中取到 0.3,则采样值为 0.8 + 0.474 * 0.3 ≈ 0.8 + 0.1422 = 0.9422。'''normal_sample = dist.rsample()  # rsample()是重参数化采样'''作用:计算刚采样值在原始正态分布下的对数概率密度。'''log_prob = dist.log_prob(normal_sample)'''作用:对采样的原始动作进行 tanh 激活,将其映射到 (-1, 1) 范围内,保证动作平滑且有界。数值例子:- 对于采样值 0.9422,torch.tanh(0.9422) ≈ 0.737。'''action = torch.tanh(normal_sample)# 计算tanh_normal分布的对数概率密度'''作用:由于经过了 tanh 非线性变换,原来的对数概率密度需要进行修正(Jacobian 修正项),这里用公式logp_action=logp_normal−log(1−tanh(action)^2+ϵ)注意:实际应用中,通常是对 normal_sample 进行修正,写法可能略有不同,但这里的目标是一致的——补偿 tanh 变换带来的概率密度变换。'''log_prob = log_prob - torch.log(1 - torch.tanh(action).pow(2) + 1e-7)action = action * self.action_boundreturn action, log_probclass QValueNetContinuous(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNetContinuous, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc_out = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)x = F.relu(self.fc1(cat))x = F.relu(self.fc2(x))return self.fc_out(x)class SACContinuous:''' 处理连续动作的SAC算法 '''"""解释:- 定义一个名为 SACContinuous 的类,用来实现针对连续动作的 Soft Actor-Critic 算法。"""def __init__(self, state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,device):"""定义构造函数,接收一系列超参数,分别代表状态维度、隐藏层神经元个数、动作维度、动作界限、各网络的学习率、目标熵、软更新参数、折扣因子和设备。"""'''# 策略网络使用前面定义的 PolicyNetContinuous 构造函数生成策略网络(actor),并将该网络放到指定设备上(例如 CPU 或 GPU)。'''self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim, action_bound).to(device)  '''# 第一个Q网络创建第一个 Q 网络,用于评估(状态,动作)对的价值,同样放到指定设备。'''self.critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device) '''# 第二个Q网络创建第二个 Q 网络,与第一个结构相同,用于双重估计,帮助缓解过估计问题。'''self.critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)  '''# 第一个目标Q网络构造第一个目标 Q 网络,其结构与 critic_1 相同,用于计算目标值(TD目标),以便实现平滑更新。'''self.target_critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device) '''# 第二个目标Q网络构造第二个目标 Q 网络,其结构与 critic_2 相同,用于目标值计算。'''self.target_critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)  '''# 令目标Q网络的初始参数和Q网络一样将 critic_1 网络的所有参数复制到 target_critic_1 中,使二者初始时完全一致。将 critic_2 网络的所有参数复制到 target_critic_2 中,使二者初始时完全一致。'''self.target_critic_1.load_state_dict(self.critic_1.state_dict())self.target_critic_2.load_state_dict(self.critic_2.state_dict())'''使用 Adam 优化器为策略网络分配优化器,学习率为 actor_lr。'''self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)'''为 critic_1 分配 Adam 优化器,学习率为 critic_lr。为 critic_2 分配 Adam 优化器,学习率为 critic_lr。'''self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)# 使用alpha的log值,可以使训练结果比较稳定'''创建一个标量张量,用于存储温度参数 alpha 的对数值。初始值设为 log(0.01) ≈ -4.6052。这样做有助于稳定训练,因为直接优化正数会带来数值不稳定问题。'''self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)'''设置该张量的 requires_grad 属性为 True,表示在反向传播时会计算关于 log_alpha 的梯度,从而能更新温度参数。'''self.log_alpha.requires_grad = True  # 可以对alpha求梯度'''为 log_alpha 创建一个 Adam 优化器,学习率为 alpha_lr。注意优化器接收的是一个包含 log_alpha 的列表。'''self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)'''保存目标熵参数,这个值用于指导策略更新时保持足够的探索性。'''self.target_entropy = target_entropy  # 目标熵的大小'''保存折扣因子,用于计算未来奖励的折现值。'''self.gamma = gamma'''保存软更新系数 tau,用于更新目标网络的参数。'''self.tau = tau'''保存设备信息,便于后续将数据和模型都放到同一设备上。'''self.device = devicedef take_action(self, state):"""定义一个方法,根据当前状态输出一个动作(供环境交互时调用)。"""'''将传入的状态(例如一个列表或数组)转换为 PyTorch 张量,并在外面加一层列表以增加 batch 维度,然后将其放到指定设备上。state = [1,2,3,4]state = torch.tensor([state], dtype=torch.float).to("cuda")print(state) # tensor([[1., 2., 3., 4.]], device='cuda:0')state1 = [1,2,3,4]state1 = torch.tensor(state1, dtype=torch.float).to("cuda")print(state1) # tensor([1., 2., 3., 4.], device='cuda:0')state2 = [1,2,3,4]state2 = torch.tensor(state2, dtype=torch.float).unsqueeze(0).to("cuda")print(state2) # tensor([[1., 2., 3., 4.]], device='cuda:0')'''if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(self.device)'''解释:- 将状态输入 actor 网络,得到输出。由于 actor 的 forward 返回的是一个元组(动作、对数概率),这里取第一个元素(动作部分)。数值例子:- 假设 actor 返回 (tensor([[0.737]]), tensor([[-0.45]])),则 action = tensor([[0.737]]);再取 [0] 后得到单个样本的动作张量 tensor([0.737])。'''action = self.actor(state)[0]'''解释:- 将动作张量转换为 Python 标量,并放入列表后返回。数值例子:- action.item() 会返回 0.737,最终返回 [0.737]。这样可以适应环境要求动作为列表格式的情况。'''return [action.item()]def calc_target(self, rewards, next_states, dones):  # 计算目标Q值"""定义一个方法,利用下一时刻状态、奖励和 done 标志计算 TD 目标(目标 Q 值),用于 critic 网络的回归训练。"""'''对所有下一状态(通常是一个 batch),利用 actor 网络计算下一时刻动作和其对应的对数概率。数值例子:假设 next_states 有 2 个样本,每个样本状态为 3 维;actor 返回- next_actions = tensor([[1.2], [0.8]])- log_prob = tensor([[-0.5], [-0.6]])'''next_actions, log_prob = self.actor(next_states)'''计算熵项,实际上熵等于负的对数概率。数值例子:如果 log_prob = tensor([[-0.5], [-0.6]]),则 entropy = tensor([[0.5], [0.6]])。'''entropy = -log_prob'''使用目标网络1计算给定下一状态和对应动作的 Q 值。'''q1_value = self.target_critic_1(next_states, next_actions)'''使用目标网络2计算给定下一状态和对应动作的 Q 值。'''q2_value = self.target_critic_2(next_states, next_actions)'''解释:- 首先,取两个目标 Q 值的最小值(用来降低过估计风险);- 然后加上温度参数 alpha(由 self.log_alpha.exp() 得到)乘以熵项,这一项鼓励探索。数值例子:- 对第一样本:min(2.0, 2.5) = 2.0,且 self.log_alpha.exp() 计算为 exp(-4.6052) ≈ 0.01;熵为 0.5,则 next_value = 2.0 + 0.01 * 0.5 = 2.0 + 0.005 = 2.005。- 对第二样本:min(3.0, 3.5) = 3.0,熵为 0.6,则 next_value = 3.0 + 0.01 * 0.6 = 3.0 + 0.006 = 3.006。'''next_value = torch.min(q1_value, q2_value) + self.log_alpha.exp() * entropy'''计算 TD 目标:td_target = 𝑟 + 𝛾 × next_value × (1−done)当 done 为 1(表示回合结束)时,不再折扣未来奖励。'''td_target = rewards + self.gamma * next_value * (1 - dones)return td_targetdef soft_update(self, net, target_net):"""定义一个方法,用于对目标网络参数做软更新。传入当前网络和对应的目标网络。"""'''遍历目标网络和当前网络中对应的每一对参数(权重和偏置)。'''for param_target, param in zip(target_net.parameters(), net.parameters()):'''对每个参数做软更新:𝜃target←(1−𝜏)𝜃target+𝜏𝜃θtarget←(1−τ)θ target+τθ这可以平滑地将目标网络参数向当前网络参数靠拢。'''param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)def update(self, transition_dict):"""定义一个方法,根据从 replay buffer 中采样的转换数据(transition)更新 actor、critic 网络以及温度参数 alpha。"""'''将 transition_dict 中的状态数据转换为浮点型张量,并放到指定设备上。'''states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)'''同理,将动作数据转换为张量,并通过 view(-1, 1) 调整形状为 (batch_size, 1)(即每个动作为一个标量)。数值例子:若 transition_dict['actions'] = [1.0, 0.5],转换后形状为 (2, 1)。'''actions = torch.tensor(transition_dict['actions'], dtype=torch.float).view(-1, 1).to(self.device)'''将奖励数据转换为形状为 (batch_size, 1) 的张量。数值例子:若 rewards = [1.0, -0.5],则转换后为形状 (2, 1)。'''rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)'''将下一时刻的状态数据转换为张量。数值例子:例如 next_states = [[1.1, 0.4, -0.1], [0.2, 0.0, 0.9]],形状 (2, 3)。'''next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)'''将 done 标志(0 或 1)转换为形状为 (batch_size, 1) 的张量,用于指示回合是否结束。数值例子:若 dones = [0, 1],则转换后为 tensor([[0.0], [1.0]])。'''dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)# 和之前章节一样,对倒立摆环境的奖励进行重塑以便训练'''对奖励进行归一化或重塑,使其数值范围更适合训练。对于倒立摆(或类似)环境,原始奖励可能范围较大,这里将所有奖励平移 8.0 后除以 8.0。数值例子:- 如果原始 reward = -8.0,则 ( -8.0 + 8.0) / 8.0 = 0;- 如果 reward = 0,则变为 1;- 如果 reward = 8,则变为 2。'''rewards = (rewards + 8.0) / 8.0# 更新两个Q网络'''调用前面定义的 calc_target 方法,根据重塑后的奖励、下一状态和 done 标志计算 TD 目标。'''td_target = self.calc_target(rewards, next_states, dones)'''计算 critic_1 的均方误差(MSE)损失。- 调用 self.critic_1(states, actions) 得到当前 Q 值估计;- 使用 td_target.detach() 表示目标值不参与梯度计算;- 用 MSE 损失函数计算误差,再取平均。数值例子:- 假设 critic_1 输出 Q 值为 2.5,td_target 为 2.98,则误差为 (2.5−2.98)^2≈0.2304;对 batch 求均值。'''critic_1_loss = torch.mean(F.mse_loss(self.critic_1(states, actions), td_target.detach()))critic_2_loss = torch.mean(F.mse_loss(self.critic_2(states, actions), td_target.detach()))'''清空 critic_1 优化器中所有累积的梯度,防止梯度累加。'''self.critic_1_optimizer.zero_grad()'''对 critic_1 损失进行反向传播,计算每个参数的梯度。'''critic_1_loss.backward()'''更新 critic_1 网络的参数,根据之前计算的梯度和设定的学习率进行一步更新。'''self.critic_1_optimizer.step()'''清空 critic_2 优化器中所有累积的梯度,防止梯度累加。'''self.critic_2_optimizer.zero_grad()'''对 critic_2 损失进行反向传播,计算每个参数的梯度。'''critic_2_loss.backward()'''更新 critic_2 网络的参数,根据之前计算的梯度和设定的学习率进行一步更新。'''self.critic_2_optimizer.step()# 更新策略网络'''使用当前 actor 网络,根据当前状态生成一组新的动作及其对数概率,用于策略更新。'''new_actions, log_prob = self.actor(states)'''计算熵项,即负的对数概率。'''entropy = -log_prob'''用当前 critic_1 网络评估新生成动作的 Q 值。'''q1_value = self.critic_1(states, new_actions)'''用当前 critic_2 网络评估新生成动作的 Q 值。'''q2_value = self.critic_2(states, new_actions)'''计算策略网络(actor)的损失。- 第一项:−𝛼 × entropy 用于鼓励策略探索;- 第二项:−min(𝑞1, 𝑞2) 表示希望选择高价值动作;- 取均值作为整个 batch 的损失。'''actor_loss = torch.mean(-self.log_alpha.exp() * entropy - torch.min(q1_value, q2_value))'''对 actor 网络进行梯度清零、反向传播和参数更新。'''self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 更新alpha值'''计算温度参数 alpha 的损失。- (entropy - self.target_entropy) 表示当前熵与目标熵之间的偏差;- 用 detach() 阻断梯度传递给 entropy(即仅更新 alpha);- 乘以当前的 𝛼 = exp{log(𝛼);- 取均值作为总体损失。'''alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())'''清空 log_alpha 的梯度、反向传播损失并更新 log_alpha 参数。'''self.log_alpha_optimizer.zero_grad()alpha_loss.backward()self.log_alpha_optimizer.step()'''调用之前定义的 soft_update 方法,对两个目标 Q 网络分别做软更新,使得目标网络参数慢慢跟随当前 Q 网络的更新。'''self.soft_update(self.critic_1, self.target_critic_1)self.soft_update(self.critic_2, self.target_critic_2)env_name = 'Pendulum-v1'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0]  # 动作最大值
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)actor_lr = 3e-4
critic_lr = 3e-3
alpha_lr = 3e-4
num_episodes = 100
hidden_dim = 128
gamma = 0.99
tau = 0.005  # 软更新参数
buffer_size = 100000
minimal_size = 1000
batch_size = 64
target_entropy = -env.action_space.shape[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, alpha_lr, target_entropy, tau,gamma, device)return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,replay_buffer, minimal_size,batch_size)    episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

SAC 算法(离散)

# -*- coding: utf-8 -*-import random
import gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import rl_utilsclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)class QValueNet(torch.nn.Module):''' 只有一层隐藏层的Q网络 '''def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class SAC:''' 处理离散动作的SAC算法 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,alpha_lr, target_entropy, tau, gamma, device):# 策略网络self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)# 第一个Q网络self.critic_1 = QValueNet(state_dim, hidden_dim, action_dim).to(device)# 第二个Q网络self.critic_2 = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_critic_1 = QValueNet(state_dim, hidden_dim,action_dim).to(device)  # 第一个目标Q网络self.target_critic_2 = QValueNet(state_dim, hidden_dim,action_dim).to(device)  # 第二个目标Q网络# 令目标Q网络的初始参数和Q网络一样self.target_critic_1.load_state_dict(self.critic_1.state_dict())self.target_critic_2.load_state_dict(self.critic_2.state_dict())self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),lr=critic_lr)self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),lr=critic_lr)# 使用alpha的log值,可以使训练结果比较稳定self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)self.log_alpha.requires_grad = True  # 可以对alpha求梯度self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],lr=alpha_lr)self.target_entropy = target_entropy  # 目标熵的大小self.gamma = gammaself.tau = tauself.device = devicedef take_action(self, state):if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(self.device)probs = self.actor(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()# 计算目标Q值,直接用策略网络的输出概率进行期望计算def calc_target(self, rewards, next_states, dones):next_probs = self.actor(next_states)next_log_probs = torch.log(next_probs + 1e-8)entropy = -torch.sum(next_probs * next_log_probs, dim=1, keepdim=True)q1_value = self.target_critic_1(next_states)q2_value = self.target_critic_2(next_states)min_qvalue = torch.sum(next_probs * torch.min(q1_value, q2_value),dim=1,keepdim=True)next_value = min_qvalue + self.log_alpha.exp() * entropytd_target = rewards + self.gamma * next_value * (1 - dones)return td_targetdef soft_update(self, net, target_net):for param_target, param in zip(target_net.parameters(),net.parameters()):param_target.data.copy_(param_target.data * (1.0 - self.tau) +param.data * self.tau)def update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)  # 动作不再是float类型rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)# 更新两个Q网络td_target = self.calc_target(rewards, next_states, dones)critic_1_q_values = self.critic_1(states).gather(1, actions)critic_1_loss = torch.mean(F.mse_loss(critic_1_q_values, td_target.detach()))critic_2_q_values = self.critic_2(states).gather(1, actions)critic_2_loss = torch.mean(F.mse_loss(critic_2_q_values, td_target.detach()))self.critic_1_optimizer.zero_grad()critic_1_loss.backward()self.critic_1_optimizer.step()self.critic_2_optimizer.zero_grad()critic_2_loss.backward()self.critic_2_optimizer.step()# 更新策略网络probs = self.actor(states)log_probs = torch.log(probs + 1e-8)# 直接根据概率计算熵entropy = -torch.sum(probs * log_probs, dim=1, keepdim=True)  #q1_value = self.critic_1(states)q2_value = self.critic_2(states)min_qvalue = torch.sum(probs * torch.min(q1_value, q2_value),dim=1,keepdim=True)  # 直接根据概率计算期望actor_loss = torch.mean(-self.log_alpha.exp() * entropy - min_qvalue)self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 更新alpha值alpha_loss = torch.mean((entropy - target_entropy).detach() * self.log_alpha.exp())self.log_alpha_optimizer.zero_grad()alpha_loss.backward()self.log_alpha_optimizer.step()self.soft_update(self.critic_1, self.target_critic_1)self.soft_update(self.critic_2, self.target_critic_2)    actor_lr = 1e-3
critic_lr = 1e-2
alpha_lr = 1e-2
num_episodes = 200
hidden_dim = 128
gamma = 0.98
tau = 0.005  # 软更新参数
buffer_size = 10000
minimal_size = 500
batch_size = 64
target_entropy = -1
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v0'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = SAC(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, alpha_lr,target_entropy, tau, gamma, device)return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,replay_buffer, minimal_size,batch_size)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

http://www.ppmy.cn/server/175278.html

相关文章

手势调控屏幕亮度:Python + OpenCV + Mediapipe 打造智能交互体验

前言 你有没有遇到过这样的情况? 夜晚玩电脑,屏幕亮得像个小太阳,晃得眼泪直流,想调暗一点,却在键盘上盲摸半天,结果误触关机键,直接黑屏;白天屏幕暗得像熄火的煤油灯,想调亮点,鼠标点来点去,调节条藏得像猫一样不见踪影。这年头,我们的设备都快能听懂人话了,怎…

【RK3588嵌入式图形编程】-SDL2-构建一个多功能的图像类

构建一个多功能的图像类 文章目录 构建一个多功能的图像类1、概述2、设计原则2.1 友好API2.2 性能2.3 反馈2.4 破坏性变更和可扩展性3、加载文件4、源矩形5、目标矩形6、渲染和缩放模式7、完整代码8、总结本文将详细介绍如何设计一个灵活的组件,方便SDL的应用程序中处理图像。…

Java 8 + Tomcat 9.0.102 的稳定环境搭建方案,适用于生产环境

一、安装 Java 8 安装 OpenJDK 8 bash sudo apt update sudo apt install openjdk-8-jdk -y 验证安装 bash java -version 应输出类似: openjdk version “1.8.0_412” OpenJDK Runtime Environment (build 1.8.0_412-8u412-ga-1~22.04-b08) OpenJDK 64-Bit Server VM (bui…

【Rust】枚举和模式匹配——Rust语言基础14

文章目录 1. 枚举类型1.2. Option 枚举 2. match 控制流结构2.1. match 对绑定值的匹配2.2. Option<T> 的匹配2.3. 通配模式以及 _ 占位符 3. if let 控制流4. 小测试 1. 枚举类型 枚举&#xff08;enumerations&#xff09;&#xff0c;也被称作 enums。枚举允许你通过…

LabVIEW旋转设备状态在线监测系统

为了提高大型旋转设备如电机和水泵的监控效率和故障诊断能力&#xff0c;用LabVIEW软件开发了一套实时监测与故障诊断系统。该系统集成了趋势分析、振动数据处理等多项功能&#xff0c;可实时分析电机电流、压力、温度及振动数据&#xff0c;以早期识别和预报故障。 ​ 项目背…

【学习笔记】语言模型的发展历程

语言模型的发展大致经历了以下四个阶段 统计语言模型(SLM) 主要建立在统计学习的理论框架下&#xff0c;尝试解决的是如下问题 p ( x t ∣ x 1 , x 2 , … x t − 1 ) (1) p(x_{t}|x_{1},x_{2},\dots x_{t-1})\tag{1} p(xt​∣x1​,x2​,…xt−1​)(1) 根据之前的历史信息预…

Django-ORM-select_related

Django-ORM-select_related 作用使用场景示例无 select_related 的查询有 select_related 的查询 如何理解 "只发起一次查询&#xff0c;包含所有相关作者信息"1. select_related 的工作原理2. 具体示例解析3. 为什么只发起一次查询 数据库中的books量巨大&#xff0…

SpringMVC(三)响应处理

目录 响应数据类型&#xff1a; 一、自动 JSON 响应 1 实现解析 二、文件下载 1 核心实现 2 优化与问题 响应数据类型&#xff1a; 一、自动 JSON 响应 1 实现解析 RestController 作用 类注解&#xff0c;自动将方法返回值序列化为 JSON&#xff08;无需 ResponseBody …