强化学习-A2C

news/2025/2/12 1:00:29/

关于A2C的介绍可以参考书本158页

流程图
此处参考强化学习–从DQN到PPO, 流程详解
在这里插入图片描述在这里插入图片描述图片来源于博客强化学习之policy-based方法A2C实现(PyTorch)

代码实现
代码参考Actor-Critic-pytorch

import gym, os
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categoricaldevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("CartPole-v0").unwrappedstate_size = env.observation_space.shape[0]
action_size = env.action_space.n
lr = 0.0001class Actor(nn.Module):def __init__(self, state_size, action_size):super(Actor, self).__init__()self.state_size = state_sizeself.action_size = action_sizeself.linear1 = nn.Linear(self.state_size, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, self.action_size)def forward(self, state):output = F.relu(self.linear1(state))output = F.relu(self.linear2(output))output = self.linear3(output)distribution = Categorical(F.softmax(output, dim=-1))return distributionclass Critic(nn.Module):def __init__(self, state_size, action_size):super(Critic, self).__init__()self.state_size = state_sizeself.action_size = action_sizeself.linear1 = nn.Linear(self.state_size, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, 1)def forward(self, state):output = F.relu(self.linear1(state))output = F.relu(self.linear2(output))value = self.linear3(output)return valuedef compute_returns(next_value, rewards, masks, gamma=0.99):R = next_valuereturns = []for step in reversed(range(len(rewards))):R = rewards[step] + gamma * R * masks[step]returns.insert(0, R)return returnsdef trainIters(actor, critic, n_iters):optimizerA = optim.Adam(actor.parameters())optimizerC = optim.Adam(critic.parameters())for iter in range(n_iters):state = env.reset()log_probs = []values = []rewards = []masks = []entropy = 0env.reset()for i in count():env.render()state = torch.FloatTensor(state).to(device)dist, value = actor(state), critic(state)action = dist.sample()next_state, reward, done, _ = env.step(action.cpu().numpy())log_prob = dist.log_prob(action).unsqueeze(0)entropy += dist.entropy().mean()log_probs.append(log_prob)values.append(value)rewards.append(torch.tensor([reward], dtype=torch.float, device=device))masks.append(torch.tensor([1-done], dtype=torch.float, device=device))state = next_stateif done:print('Iteration: {}, Score: {}'.format(iter, i))breaknext_state = torch.FloatTensor(next_state).to(device)next_value = critic(next_state)returns = compute_returns(next_value, rewards, masks)log_probs = torch.cat(log_probs)returns = torch.cat(returns).detach()values = torch.cat(values)advantage = returns - valuesactor_loss = -(log_probs * advantage.detach()).mean()critic_loss = advantage.pow(2).mean()optimizerA.zero_grad()optimizerC.zero_grad()actor_loss.backward()critic_loss.backward()optimizerA.step()optimizerC.step()torch.save(actor, 'model/actor.pkl')torch.save(critic, 'model/critic.pkl')env.close()if __name__ == '__main__':if os.path.exists('model/actor.pkl'):actor = torch.load('model/actor.pkl')print('Actor Model loaded')else:actor = Actor(state_size, action_size).to(device)if os.path.exists('model/critic.pkl'):critic = torch.load('model/critic.pkl')print('Critic Model loaded')else:critic = Critic(state_size, action_size).to(device)trainIters(actor, critic, n_iters=100)

http://www.ppmy.cn/news/880856.html

相关文章

REINFORCE和A2C的异同

两者的神经网络结构一模一样,都是分为两个网络,即策略神经网络和价值神经网络。但是两者的区别在于价值神经网络的作用不同,A2C中的可以评价当前状态的好坏,而REINFORCE中的只是作为一个Baseline而已,唯一作用就是降低…

Actor-Critic(A2C)算法 原理讲解+pytorch程序实现

文章目录 1 前言2 算法简介3 原理推导4 程序实现5 优缺点分析6 使用经验7 总结 1 前言 强化学习在人工智能领域中具有广泛的应用,它可以通过与环境互动来学习如何做出最佳决策。本文将介绍一种常用的强化学习算法:Actor-Critic并且附上基于pytorch实现的…

A2C算法原理及代码实现

本文主要参考王树森老师的强化学习课程 1.A2C算法原理 A2C算法是策略学习中比较经典的一个算法,是在 Barto 等人1983年提出的。我们知道策略梯度方法用策略梯度更新策略网络参数 θ,从而增大目标函数,即下面的随机梯度: Actor-C…

强化学习算法A2C(Advantage Actor-Critic)和A3C(Asynchronous Advantage Actor-Critic)算法详解以及A2C的Pytorch实现

一、策略梯度算法回顾 策略梯度(Policy Gradient)算法目标函数的梯度更新公式为: ▽ R ˉ θ 1 N ∑ n 1 N ∑ t 1 T n ( ∑ t ′ t T n γ t ′ − t r t ′ n − b ) ▽ l o g p θ ( a t n ∣ s t n ) (1) \bigtriangledown \bar{R}…

Unity 3D 脚本编程与游戏开发 学习笔记

学习笔记 内容提要Unity脚本概览控制物体移动触发器事件 Unity 基本概念与脚本编程物体、组件和对象创建物体实例——3D射击游戏 内容提要 全书从建立编程脚本和游戏框架为出发点,逐步阐述游戏开发中的核心概念,核心的物理系统和数学基础,然…

【Rust 基础篇】Rust 自定义迭代器

导言 在 Rust 中,自定义迭代器可以帮助我们根据特定需求实现符合自己逻辑的迭代过程。自定义迭代器是通过实现 Iterator trait 来完成的。本篇博客将详细介绍如何在 Rust 中自定义迭代器,包括自定义迭代器的定义、必要的方法和一些常见的使用场景。 自…

解“冰刃”的使用方法

冰刃——IceSWord是一斩断黑手的利刃 。它适用于windows 2000/XP/2003操作系统,用于查探系统中的幕后黑手(木马后门)并作出处理,当然使用它需要用户有一些操作系统的知识。  在对软件做讲解之前,首先说明第一注意事项:此程序运行…

华硕ROG冰刃6双屏原厂预装Windows11系统工厂恢复带ASUSRecovery恢复功能

华硕工厂恢复系统 ,安装结束后带隐藏分区以及机器所有驱动软件,奥创 文件地址: https://pan.baidu.com/s/1Pq09oDzmFI6hXVdf8Vqjqw?pwd3fs8 提取码:3fs8 文件格式:5个底包(HDI KIT COM MCAFEE EDN) 1个引导工具TLK 支持ASUSRECOVERY型号 冰刃7双屏…