ML-Agents:训练配置文件(一)

embedded/2024/12/26 18:01:08/

注:本文章为官方文档翻译,如有侵权行为请联系作者删除
Training Configuration File - Unity ML-Agents Toolkit–原文链接

ML-Agents:训练配置文件(一)
ML-Agents:训练配置文件(二)

常见训练器配置

关于训练,您需要做出的首要决定之一是使用哪种训练器:PPO、SAC 还是 POCA。有些训练配置是两种训练器都通用的(我们现在将对此进行回顾),而其他训练配置则取决于训练器的选择(我们将在后续章节中进行回顾)。

环境描述
trainer_type(默认值 = ppo)要使用的训练器类型:pposacpoca
summary_freq(默认值 = 50000)在生成和显示训练统计数据之前需要收集的经验数。这决定了 Tensorboard 中图表的粒度。
time_horizon(默认值 = 64)在将每个Agent的经验添加到经验缓冲区之前,需要收集多少步经验。当在情节结束前达到此限制时,将使用价值估计来预测Agent当前状态的总体预期奖励。因此,此参数在偏差较小但方差较大的估计(长期范围)和偏差较大但变化较小的估计(短期范围)之间进行权衡。如果情节中奖励频繁,或者情节过大,则较小的数字可能更为理想。这个数字应该足够大,以捕捉Agent动作序列中的所有重要行为。

典型范围:32-2048
max_steps(默认值 = 500000)在结束训练过程之前必须在环境中(或如果并行使用多个,则在所有环境中)采取的总步骤数(即收集的观察结果和采取的行动)。如果您的环境中有多个具有相同行为名称的Agent,则这些Agent采取的所有步骤都将计入相同的max_steps计数。

典型范围:5e5-1e7
keep_checkpoints(默认值 = 5)要保留的模型检查点的最大数量。检查点会在由 checkpoint_interval 选项指定的步骤数之后保存。一旦达到最大检查点数量,则在保存新检查点时会删除最旧的检查点。
even_checkpoints(默认值 = false)如果设置为 true,则忽略checkpoint_interval并根据keep_checkpointsmax_steps在整个训练过程中均匀分布检查点,即checkpoint_interval = max_steps / keep_checkpoints。在训练过程中记录Agent行为时非常有用。
checkpoint_interval(默认值 = 500000)训练器在每个检查点之间收集的经验数。keep_checkpoints在删除旧检查点之前,最多可保存检查点数。每个检查点都会将.onnx文件保存在results/文件夹中。
init_path(默认值 = None)从之前保存的模型初始化训练器。请注意,之前的运行应该使用与当前运行相同的训练器配置,并使用相同版本的 ML-Agents 保存。

您可以提供文件名或检查点的完整路径,例如{checkpoint_name.pt}./models/{run-id}/{behavior_name}/{checkpoint_name.pt}。如果您想从不同的运行初始化不同的行为或从较旧的检查点初始化,则提供此选项;在大多数情况下,使用--initialize-fromCLI 参数从同一运行初始化所有模型就足够了。
threaded(默认值 = false)允许环境在更新模型时进行迭代。这可能导致训练速度加快,尤其是在使用SAC的情况下。在使用自我博弈时,为了获得最佳性能,请将此设置设置为 false
hyperparameters -> learning_rate(默认值 = 3e-4)梯度下降的初始学习率。对应于每个梯度下降更新步骤的强度。如果训练不稳定且奖励未持续增加,通常应将此值减小。

典型范围:1e-5-1e-3
hyperparameters -> batch_size每次迭代中使用的经验数。这应该始终比buffer_size小几倍。如果您使用的是连续操作,这个值应该较大(以1000为数量级)。如果您只使用离散的操作,则该值应该较小(以10为数量级)。

典型范围:(连续 - PPO):512- 5120;(连续 - SAC):128- 1024;(离散、PPO 和 SAC):32- 512
hyperparameters -> buffer_size(PPO默认为10240, SAC默认为50000

PPO:在更新策略模型之前要收集的经验数。对应于我们在学习或更新模型之前应该收集多少经验。
这个值应该batch_size大好几倍。通常,较大的buffer_size对应于更稳定的训练更新。

SAC :经验缓冲区的最大大小 - 比您的情节长数千倍,以便 SAC 可以从旧经验和新经验中学习。

典型范围:PPO:2048- 409600;SAC:50000-1000000
hyperparameters -> learning_rate_schedule(默认值 = PPO:linear ;SAC:constant)决定学习率如何随时间变化。对于PPO,我们建议将学习率逐渐降低至max_steps,以便学习更稳定地收敛。然而,对于某些情况(例如训练时间不确定),可以禁用此功能。对于SAC,我们建议保持学习率不变,以便代理可以在其Q函数自然收敛之前继续学习。

linear线性衰减学习率,在 max_steps 处达到 0,同时constant在整个训练过程中保持学习率不变。
network_settings -> hidden_units(默认值 = 128神经网络隐藏层中的单元数。对应于神经网络每个全连接层中的单元数。对于简单问题,其中正确的操作是观察输入的直接组合,这个值应该较小。对于问题复杂,动作是观察变量之间复杂交互的情况,这个值应该较大。

典型范围:32-512
network_settings -> num_layers(默认值 = 2神经网络中的隐藏层数。对应于在观察输入之后或在视觉观察的CNN编码之后存在的隐藏层数。对于简单问题,较少的隐藏层可能训练得更快、更高效。对于更复杂的控制问题,可能需要更多的隐藏层。

典型范围:1-3
network_settings -> normalize(默认值 = false)是否对向量观测输入应用归一化。这种归一化基于向量观测的移动平均值和方差。归一化在具有复杂连续控制问题的情况下可能有所帮助,但在具有更简单的离散控制问题的情况下可能有害。
network_settings -> vis_encode_type(默认 = simple) 用于对视觉观察进行编码的编码器类型。

simple(默认)使用由两个卷积层组成的简单编码器。
nature_cnn使用Mnih 等人提出的由三个卷积层组成的CNN 实现。
resnet使用由三层堆叠结构组成IMPALA Resnet,每层包含两个残差块,因此其网络结构比其他两个要大得多。
match3是一个较小的 CNN(Gudmundsoon 等人),可以捕捉更细微的空间关系,并针对棋盘游戏进行了优化。
fully_connected 使用一个没有卷积层的单个全连接稠密层作为编码器。

由于卷积核的大小,每种类型的编码器都有最小观察大小限制 - simple : 20x20, nature_cnn : 36x36, resnet : 15 x 15, match3 : 5x5. fully_connected 没有卷积层,因此没有大小限制,但因为它的表示能力较弱,因此应该仅用于非常小的输入。请注意,使用 match3 的 CNN 处理非常大的视觉输入可能会导致观察编码非常大,从而可能减慢训练速度或导致内存问题。
network_settings -> conditioning_type(默认值 = hyper)使用目标观测值的策略条件类型。

它将目标观测视为常规观测,(默认情况下)使用带有目标观测作为输入的HyperNetwork来生成一些策略的权重。请注意,当使用 hyper 时,网络的参数数量会大大增加。因此,建议在使用此 conditioning_type时减少 hidden_units 的数量。

鉴于作者水平有限,本文可能存在不足之处,欢迎各位读者提出指导和建议,共同探讨、共同进步。


http://www.ppmy.cn/embedded/148956.html

相关文章

对 MYSQL 架构的了解

MySQL 是一种广泛使用的关系型数据库管理系统,其架构主要包括以下几个关键部分: 一、连接层 客户端连接管理:MySQL 服务器可以同时处理多个客户端的连接请求。当客户端应用程序(如使用 Java、Python 等语言编写的程序)…

Llama 3 模型系列解析(一)

目录 1. 引言 1.1 Llama 3 的简介 1.2 性能评估 1.3 开源计划 1.4 多模态扩展 ps 1. 缩放法则 2. 超额训练(Over-training) 3. 计算训练预算 4. 如何逐步估算和确定最优模型? 2. 概述 2.1 Llama 3 语言模型开发两个主要阶段 2.2…

云手机+YouTube:改变通信世界的划时代技术

随着科技的不断进步,手机作为人们生活中不可或缺的工具,也在不断地更新换代。近年来,一个名为“油管云手机”的全新产品正在引起广泛的关注和讨论。作为一个运用最新科技实现的新型手机,它在通信领域带来了全新的体验和革命性的变…

从汽车企业案例看仓网规划的关键步骤(视频版)

大家好,欢迎收看本期视频。本期内容将以汽车行业为例,带您了解仓库选址和仓网规划如何解决实际问题,以及需要注意的关键步骤。 案例描述 国内某大型汽车企业目前在全国范围内拥有十多个生产厂地和近千家供应商。这些供应商分布在多个地理区…

errant是怎么产生的

目录 1.产生errant GTID的原因2.检查errant GTID3.处理errant GTID方式一 忽略errant GTID方式二 重置从库方式三 注入空事务 在MySQL中,errant GTID(错误GTID)是指在从库上存在但在主库上不存在的GTID。 这通常是由于在从库上执行了不应存在…

Gemini 2.0:面向智能体时代的全新 AI 模型

2024年12月11日,Google 发布了 Gemini 2.0 系列的首个模型——Gemini 2.0 Flash(实验版)。凭借多模态方面的新进展以及原生工具的使用,Gemini 2.0 Flash (实验版) 能够构建新的 AI 智能体,推动了实现通用 AI 助手愿景的…

界面化管理Nginx的工具—NginxUI简介与搭建

转载说明:如果您喜欢这篇文章并打算转载它,请私信作者取得授权。感谢您喜爱本文,请文明转载,谢谢。 1. NginxUI简介 1.1 NginxUI介绍 Nginx UI 是一个全新的 Nginx 网络管理界面,旨在简化 Nginx 服务器的管理和配置。…

UI自动化测试实战实例

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 今天来说说pytest吧,经过几周的时间学习,有收获也有疑惑,总之最后还是搞个小项目出来证明自己的努力不没有白费。 环境准备 1…