Stable Baselines/RL算法/A2C

news/2025/2/12 0:42:57/

Stable Baselines官方文档中文版 Github CSDN
尝试翻译官方文档,水平有限,如有错误万望指正

Asynchronous Advantage Actor Critic (A3C)的同步、确定性变体。它使用多个workers来避免使用重播缓存。

  • 要点核心

    • 原始文献: https://arxiv.org/abs/1602.01783
    • OpenAI 博客: https://openai.com/blog/baselines-acktr-a2c/
    • python -m stable_baselines.a2c.run_atariAtari游戏以 40M frames = 10M timesteps运行算法。更多选项参见帮助文档(-h
    • python -m stable_baselines.a2c.run_mujocoMujoco环境以1M frames运行算法
  • 适用情况

    • 迭代策略:✔️

    • 多进程:✔️

    • Gym空间:

      SpaceActionObservation
      Discrete✔️✔️
      Box✔️✔️
      MultiDiscrete✔️✔️
      MultiBinary✔️✔️
  • 案例

    用4进程在CartPole-v1上训练A2C agent

    import gymfrom stable_baselines.common.policies import MlpPolicy
    from stable_baselines.common.vec_env import SubprocVecEnv
    from stable_baselines import A2C# multiprocess environment
    n_cpu = 4
    env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])model = A2C(MlpPolicy, env, verbose=1)
    model.learn(total_timesteps=25000)
    model.save("a2c_cartpole")del model # remove to demonstrate saving and loadingmodel = A2C.load("a2c_cartpole")obs = env.reset()
    while True:action, _states = model.predict(obs)obs, rewards, dones, info = env.step(action)env.render()
    
  • 参数

    stable_baselines.a2c.A2C(policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.01, max_grad_norm=0.5, learning_rate=0.0007, alpha=0.99, epsilon=1e-05, lr_schedule='constant', verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False)
    

    A2C(Adavantage Actor Critic)模型类, https://arxiv.org/abs/1602.01783

    参数数据类型意义
    policyActorCriticPolicy or str所用策略模型(MlpPolicy, CnnPolicy, CnnLstmPolicy, …)
    envGym environment or str学习所用环境(如果注册在Gym,可以是str)
    gammafloat贴现因子
    n_stepsint运行环境每次更新所用时间步(例如:当n_env是同时运行的环境副本数量时,batch=n_steps*n_env)
    vf_coeffloat用于损失函数的价值函数系数
    ent_coeffloat损失函数的信息熵系数
    max_grad_normfloat梯度裁剪的最大值
    learning_ratefloat学习率
    alphafloatRMSProp衰减参数(默认:0.99)
    epsilonfloatRMSProp epsilon(稳定RMSProp更新中分母的平方根计算)(默认1e-5)
    lr_schedulestr更新学习率的调度程序类型(‘linear’, ‘constant’, ‘double_linear_con’, ‘middle_drop’ or ‘double_middle_drop’)
    verboseint日志信息级别:0None;1训练信息;2tensorflow调试
    tensorboard_logstrtensorboard的日志位置(如果时None,没有日志)
    _init_setup_modelbool实例化创建过程中是否建立网络(只用于载入)
    policy_kwargsdict创建过程中传递给策略的额外参数
    full_tensorboard_logbool当使用tensorboard时,是否记录额外日志(这个日志会占用大量空间)
    • action_probability(observation, state=None, mask=None, actions=None, logp=False)

      如果actionsNone,那么从给定观测中获取模型的行动概率分布。

      输出取决于行动空间:

      • 离散:每个可能行动的概率
      • Box:行动输出的均值和标准差

      然而,如果actions不是None,这个函数会返回给定行动与参数(观测,状态,…)用于此模型的概率。对于离散行动空间,它返回概率密度;对于连续行动空间,则是概率密度。这是因为在连续空间,概率密度总是0,更详细的解释见 http://blog.christianperone.com/2019/01/

      参数数据类型意义
      observationnp.ndarray输入观测
      statenp.ndarray最新状态(可以时None,用于迭代策略)
      masknp.ndarray最新掩码(可以时None,用于迭代策略)
      actionsnp.ndarray(可选参数)为计算模型为每个给定参数选择给定行动的似然。行动和观测必须具有相同数目(None返回完全动作分布概率)
      logpbool(可选参数)当指定行动,返回log空间的概率。如果action是None,则此参数无效

      返回:np.ndarray)模型的(log)行动概率

    • get_env()

      返回当前环境(如果没有定义可以是None

      返回:Gym Environment)当前环境

    • get_parameter_list()

      获取模型参数的tensorflow变量

      包含连续训练(保存/载入)所用的所有必要变量

      返回:listtensorflow变量列表

    • get_parameters()

      获取当前模型参数作为变量名字典 -> ndarray

      返回:(OrderedDict)变量名字典 -> 模型参数的ndarray

    • learn(total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name=‘A2C’, reset_num_timesteps=True)

      返回一个训练好的模型

      参数数据类型意义
      total_timestepsint训练用样本总数
      seedint训练用初始值,如果None:保持当前种子
      callbackfunction (dict, dict)与算法状态的每步调用的布尔函数。采用局部或全局变量。如果它返回false,训练被终止
      log_intervalint记录日志之前的时间步数
      tb_log_namestr运行tensorboard日志的名称
      reset_num_timestepsbool是否重置当前时间步数(日志中使用)

      返回:(BaseRLModel) 训练好的模型

    • **classmethod load(load_path, env=None, kwargs)

      从文件中载入模型

      参数数据类型意义
      load_pathstr or file-like文件路径
      envGym Envrionment载入模型运行的新环境(如果你只是从训练好的模型来做预测可以是None)
      kwargs载入过程中能改变模型的额外
    • load_parameters(load_path_or_dict, exact_match=True)

      从文件或字典中载入模型参数

      字典关键字是tensorflow变量名,可以用get_parameters函数获取。如果exact_matchTrue,字典应该包含所有模型参数的关键字,否则报错RunTimeError。如果是False,只有字典包含的变量会被更新。

      此函数并不载入agent的超参数

      警告:

      此函数不更新训练器/优化器的变量(例如:momentum)。因为使用此函数的这种训练可能会导致低优化结果

      参数数据类型意义
      load_path_or_dictstr or file-like保存参数或变量名字典位置->载入的是ndarrays
      exact_matchbool如果是True,期望载入关键字包含模型所有变量的字典;如果是False,只载入字典中提及的参数。默认True
    • predict(observation, state=None, mask=None, deterministic=False)

      获取从参数得到的模型行动

      参数数据类型意义
      observationnp.ndarray输入观测
      statenp.ndarray最新状态(可以时None,用于迭代策略)
      masknp.ndarray最新掩码(可以时None,用于迭代策略)
      deterministicbool是否返回确定性的行动

      返回:(np.ndarray, np.ndarray) 模型的行动和下一状态(用于迭代策略)

    • pretrain(dataset, n_epochs=10, learning_rate=0.0001, adam_epsilon=1e-08, val_interval=None)

      用行为克隆预训练一个模型:在给定专家数据集上的监督学习

      目前只支持Box和离散空间

      参数数据类型意义
      datasetExpertDataset数据集管理器
      n_epochsint训练集上的迭代次数
      learning_ratefloat学习率
      adam_epsilonfloatadam优化器的 ϵ \epsilon ϵ
      val_intervalint报告每代的训练和验证损失。默认最大纪元数的十分之一

      返回:(BaseRLModel) 预训练好的模型

    • save(save_path)

      保存当前参数到文件

      参数数据类型意义
      save_pathstr or file-like object保存位置
    • set_env(env)

      检查环境的有效性,如果是一致的,将其设置为当前环境

      参数数据类型意义
      envGym Environmentx学习一个策略的环境
    • setup_model()

      创建训练模型所必须的函数的tensorflow图表


http://www.ppmy.cn/news/880857.html

相关文章

强化学习-A2C

关于A2C的介绍可以参考书本158页 流程图 此处参考强化学习–从DQN到PPO, 流程详解 图片来源于博客强化学习之policy-based方法A2C实现(PyTorch) 代码实现 代码参考Actor-Critic-pytorch import gym, os from itertools import count impo…

REINFORCE和A2C的异同

两者的神经网络结构一模一样,都是分为两个网络,即策略神经网络和价值神经网络。但是两者的区别在于价值神经网络的作用不同,A2C中的可以评价当前状态的好坏,而REINFORCE中的只是作为一个Baseline而已,唯一作用就是降低…

Actor-Critic(A2C)算法 原理讲解+pytorch程序实现

文章目录 1 前言2 算法简介3 原理推导4 程序实现5 优缺点分析6 使用经验7 总结 1 前言 强化学习在人工智能领域中具有广泛的应用,它可以通过与环境互动来学习如何做出最佳决策。本文将介绍一种常用的强化学习算法:Actor-Critic并且附上基于pytorch实现的…

A2C算法原理及代码实现

本文主要参考王树森老师的强化学习课程 1.A2C算法原理 A2C算法是策略学习中比较经典的一个算法,是在 Barto 等人1983年提出的。我们知道策略梯度方法用策略梯度更新策略网络参数 θ,从而增大目标函数,即下面的随机梯度: Actor-C…

强化学习算法A2C(Advantage Actor-Critic)和A3C(Asynchronous Advantage Actor-Critic)算法详解以及A2C的Pytorch实现

一、策略梯度算法回顾 策略梯度(Policy Gradient)算法目标函数的梯度更新公式为: ▽ R ˉ θ 1 N ∑ n 1 N ∑ t 1 T n ( ∑ t ′ t T n γ t ′ − t r t ′ n − b ) ▽ l o g p θ ( a t n ∣ s t n ) (1) \bigtriangledown \bar{R}…

Unity 3D 脚本编程与游戏开发 学习笔记

学习笔记 内容提要Unity脚本概览控制物体移动触发器事件 Unity 基本概念与脚本编程物体、组件和对象创建物体实例——3D射击游戏 内容提要 全书从建立编程脚本和游戏框架为出发点,逐步阐述游戏开发中的核心概念,核心的物理系统和数学基础,然…

【Rust 基础篇】Rust 自定义迭代器

导言 在 Rust 中,自定义迭代器可以帮助我们根据特定需求实现符合自己逻辑的迭代过程。自定义迭代器是通过实现 Iterator trait 来完成的。本篇博客将详细介绍如何在 Rust 中自定义迭代器,包括自定义迭代器的定义、必要的方法和一些常见的使用场景。 自…

解“冰刃”的使用方法

冰刃——IceSWord是一斩断黑手的利刃 。它适用于windows 2000/XP/2003操作系统,用于查探系统中的幕后黑手(木马后门)并作出处理,当然使用它需要用户有一些操作系统的知识。  在对软件做讲解之前,首先说明第一注意事项:此程序运行…