Online Decision Transformer

news/2025/1/18 10:45:39/

摘要

  • 最近的工作表明,离线强化学习 (RL) 可以表述为序列建模问题 (Chen et al., 2021; Janner et al., 2021),并通过类似于大规模语言建模的方法来解决。 然而,RL 的任何实际实例化还涉及在线组件,其中在被动离线数据集上预训练的策略通过与环境的特定任务交互进行微调。 我们提出了在线决策Transformer(ODT),这是一种基于序列建模的 RL 算法,将离线预训练与在线微调融合在一个统一的框架中。 我们的框架使用序列级熵正则化器与自回归建模目标相结合,以实现样本有效的探索和微调。 根据经验,我们表明 ODT 在 D4RL 基准测试的绝对性能上与最先进的技术具有竞争力,但在微调过程中显示出更显着的收益。

引言

  • 序列建模的生成式预训练已成为许多领域和模式中机器学习的统一范式,特别是在语言和视觉方面(Radford 等人,2018;Chen 等人,2020;Brown 等人,2020;Lu 等人, 2022)。 最近,这种预训练范式已扩展到离线强化学习 (RL)(Chen 等人,2021 年;Janner 等人,2021 年),其中训练代理以自回归最大化离线数据集中轨迹的可能性。 在训练期间,这种范式本质上将离线 RL 转换为监督学习问题(Schmidhuber,2019;Srivastava 等人,2019;Emmons 等人,2021)。 然而,这些工作呈现出不完整的画面,因为通过离线 RL 学习的策略受到训练数据集质量的限制,需要通过在线交互对感兴趣的任务进行微调。 这种监督学习范式是否可以扩展到在线环境仍然是一个悬而未决的问题。
  • 与语言和感知不同,RL 的在线微调与预训练阶段根本不同,因为它涉及通过探索获取数据。 对探索的需求使得离线 RL 的传统监督学习目标(例如,均方误差)在在线环境中不足。 此外,据观察,对于标准在线算法,访问离线数据通常会对在线性能产生零甚至负面影响(Nair 等,2020)。 因此,离线预训练的整体流程以及对 RL 策略的在线微调需要仔细考虑训练目标和协议。
  • 我们介绍了在线决策变压器(ODT),这是一种RL的学习框架,它将离线预训练与在线微调相结合,以实现样本高效的策略优化。我们的框架建立在先前为离线RL引入的决策变换器(DT)(Chen et al,2021)架构的基础上,特别适用于在线交互成本高昂的场景,这需要离线预训练和样本有效的微调。我们确定了与DTs不兼容的几个关键缺点,并对其进行了纠正,从而为我们的整体渠道带来了卓越的性能。
  • 首先,我们从确定性政策转向随机政策,以确定在线阶段的探索目标。我们通过类似于最大RL框架的策略熵来量化探索(Levine,2018。然而,与传统框架不同的是,ODT的策略熵在轨迹的总体水平上受到限制(与单个时间步长相反),并且其双重形式规范了监督学习目标(与直接收益最大化相反)。接下来,我们开发了一种与ODT的体系结构和训练协议一致的新型重放缓冲区(Mnih等,2015)。缓冲区存储轨迹,并通过ODT的在线推出进行填充。由于ODT参数化返回条件策略,我们进一步研究在在线推出期间指定所需返回的策略。但是,此值可能与推出期间观察到的真实返回不匹配。为了应对这一挑战,我们将后验experience replay(Andrychowicz等,2017)的概念扩展到我们的设置,并在增强它们之前用正确的return tokens重新标记推出的轨迹。
  • 根据经验,我们通过将其性能与D4RL基准上的最新算法进行比较来验证我们的整体框架(Fu等,2020)。我们发现,由于我们的微调策略,我们的相对改进优于其他基线(Nair等,2020;Kostrikov等,2021a),同时在考虑基础模型的预训练结果时表现出竞争性的绝对表现。最后,我们通过严格的消融和额外的实验设计来补充我们的主要结果,以证明和验证我们方法的关键组成部分。

相关工作

  • 我们的工作包括两个广泛的研究途径,我们在这里详细介绍。
  • RL的Transformer。最近令人兴奋的进展是将离线RL问题制定为上下文条件序列建模问题(Chen等,2021;Janner等,2021)。这些工作建立在强化学习作为监督学习范式(Schmidhuber,2019;Srivastava等,2019;Emmons等,2021)的基础上,其侧重于以任务规范(例如,target goal或return)为条件的动作序列的预测建模,而不是显式学习Q函数或策略梯度。Chen等人(2021)将Transformer训练为无模型上下文条件策略,Janner等人(2021)将Transformer训练为策略和模型,并表明波束搜索可用于改进纯无模型性能。然而,这些工作仅探索离线RL设置,其类似于Transformer传统上在自然语言处理应用中训练的固定数据集。我们的工作重点是将这些结果扩展到在线微调环境,显示出与最先进的RL方法的竞争力。
  • 离线RL方法主要将conservative保守组件添加到现有的非策略RL方法中以防止分布外推,但需要对超参数进行许多调整和重新调整才能起作用(Kumar等,2020a;Kostrikov等,2021b))。与我们的工作类似,Fujimoto和Gu(2021)展示了将行为克隆术语添加到离线RL方法的好处,并且该术语的简单添加允许将非策略RL算法以最小的变化移植到离线设置。
  • 离线RL与在线微调。虽然ODT源于与传统RL方法不同的视角,但现有的许多工作都集中在对给定离线数据集进行预训练的相同范例上,并在在线环境中进行微调。Nair等人(2020)表明,将离线或非策略性RL方法应用于离线预培训和在线微调制度往往无助于甚至阻碍绩效。这种政策外方法的不良表现可归因于政策外引导错误积累(Munos,20032005;Farahmand等,2010;Kumar等,2019)。在离线RL方法中,在线微调制度中的不良表现可以通过过度保守来解释,这在离线制度中是必要的,以防止价值高估超出分配状态。Nair等人(2020)首次提出了一种适用于离线和在线培训制度的算法。最近的工作(Kostrikov et al。,2021a)也提出了一种基于期望的离线RL隐式Q学习算法,该算法也显示出强大的在线微调性能,因为该策略是通过避免分发外行为的行为克隆步骤提取的动作。
  • Lee等人(2021)通过平衡重放方案和一系列功能来解决离线在线设置问题,以在离线培训期间保持保守主义。Lu等人(2021)改进了AWAC(Nair et al。,2020),它在在线微调阶段表现出崩溃,在在线阶段纳入了积极的抽样和探索。我们还发现积极的抽样和探索是良好的在线微调的关键,但是我们将展示ODT中这些特征是如何自然发生的,从而导致一种简单的端到端方法,可以自动适应离线和在线设置。

预赛

  • 我们假设我们的环境可以建模为马尔可夫决策过程 (MDP),可以描述为 M=<S,A,p,P,R,γ>M=<S, A, p, P, R, γ>M=<S,A,p,P,R,γ>,其中 SSS 是状态空间,AAA 是动作空间,P(st+1∣st,at)P(s_{t+1}|s_t,a_t)P(st+1stat) 是转换的概率分布,R(st,at)R(s_t,a_t)R(stat) 是奖励函数,γγγ 是折扣因子(Bellman,1957)。 代理从从固定分布 p(s1)p(s_1)p(s1) 采样的初始状态 s1s_1s1 开始,然后在每个时间步 ttt 它从状态 st∈Ss_t \in SstSat∈Aa_t \in AatA 采取行动并移动到下一个状态 st+1P(⋅∣st,at)s_{t+1}~P(\cdot |s_t, a_t)st+1 P(st,at)。 在每个动作之后,代理都会收到一个确定性的奖励 rt=R(st,at)r_t=R(s_t,a_t)rt=R(st,at)。 请注意,我们的算法也直接适用于部分可观察马尔可夫决策过程 (POMDP),但我们使用 MDP 框架以便于阐述。

3.1 设置和符号

  • 我们对决策转换器 (DT) 的在线微调感兴趣(Chen 等人,2020 年),其中代理将可以访问非平稳训练数据分布 TTT。最初,在预训练期间,TTT 对应于线数据分布 TofflineT_{offline}Toffline,并通过离线数据集 TofflineT_{offline}Toffline 访问。 在微调期间,它通过重播缓冲区 TreplayT_{replay}Treplay 访问。 令 τττ 表示轨迹并令 ∣τ∣|τ |τ 表示它的长度。 轨迹 τττ 在时间步长 ttt 的返回 (RTG),gt “ř|τ|t1“trt1 ,是该时间步长的未来奖励总和。 让“ pa1, . . . , a|τ|q, s “ ps1, . . . , s|τ|q 和 g “ pg1, . . . , g|τ|q 分别表示 τ 的动作序列、状态和 RTG。

Online Decision Transformer

  • 由于训练数据的限制,在纯离线数据集上训练的 RL 策略通常不是最优的,因为离线轨迹可能没有高回报并且仅覆盖状态空间的有限部分。 提高性能的一种自然策略是通过在线交互微调预训练的 RL 代理。 然而,标准决策转换器的学习公式不足以进行在线学习,正如我们将在实验消融中展示的那样,当天真地用于在线数据采集时会崩溃。 在本节中,我们介绍了对决策转换器的关键修改,以实现高效采样的在线微调。
  • 作为第一步,我们提出了一个广义的概率学习目标。 我们将扩展此公式以解释在线决策转换器 (ODT) 中的探索。 在概率设置中,我们的目标是学习最大化数据集可能性的随机策略。 例如,对于连续动作空间,我们可以使用具有对角线的多元高斯分布的标准选择(Haarnoja 等人,2018a;Fujimoto & Gu,2021;Kumar 等人,2020b;Emmons 等人,2021) 协方差矩阵,用于模拟以状态和 RTG 为条件的动作分布。 让 θθθ 表示策略参数。 正式地,我们的政策是
    在这里插入图片描述
  • 其中协方差矩阵 ΣθΣ_θΣθ 假定为对角矩阵。 给定随机策略,我们最大化训练数据集中轨迹的对数似然,或等效地最小化负对数似然 (NLL) 损失。
    在这里插入图片描述
  • 我们在这里考虑的策略包含 DT 考虑的确定性策略。 优化目标 (1) 等同于优化 (3),假设协方差矩阵 Σθ 是对角矩阵并且方差在所有维度上都相同,这是我们假设涵盖的特例。

4.1 最大熵序列建模

  • 在线 RL 算法的关键属性是能够平衡探索-开发权衡。 即使使用随机策略,如等式 (3) 中的传统 DT 公式也没有考虑探索。 为了解决这个缺点,我们首先通过定义为的策略熵来量化探索:
    在这里插入图片描述
  • 其中 H[πθ(ak)]H[π_θ(a_k)]H[πθ(ak)] 表示分布 πθ(ak)π_θ(a_k)πθ(ak) 的香农熵。 策略熵取决于数据分布 TTT,它在离线预训练阶段是静态的,但在微调期间是动态的,因为它取决于探索期间获得的在线数据。
  • 类似于许多现有的 max-ent RL 算法 (Levine, 2018),例如 Soft Actor Critic (SAC, Haarnoja et al. (2018a;b)),我们明确地对策略熵施加一个下限以鼓励探索。 也就是说,我们有兴趣解决以下约束问题:
    在这里插入图片描述
  • 其中 β 是一个前缀超参数。 继 Haarnoja 等人 (2018b) 之后,在实践中,我们解决了 (5) 的对偶问题,以避免显式处理不等式约束。 即,我们考虑拉格朗日 Lpθ, λq “ Jpθq `λpβ ´HTθra|s, gsq 并通过交替优化 θ 和 λ 来解决问题 maxλě0minθLpθ, λq。 用固定的 λ 优化 θ 等同于
  • 最后,我们就实际优化细节方面与 SAC 的相似性发表了几点评论。 首先,我们没有完全解决子问题(6)和(7)。 对于它们两者,我们每次只进行一次梯度更新,也就是一步交替梯度下降。 其次,HTθ ra|s, gs 的计算涉及积分。 我们使用单样本蒙特卡洛估计来近似每个积分,并且样本被重新参数化以进行低方差梯度计算。 正如 Haarnoja 等人 (2018b) 也指出的那样,我们经常观察到问题 (5) 中的约束很紧,因此实际熵 HTθra|s, gs 与 β 匹配。

Training Pipeline

  • 我们使用变压器架构实例化上述公式。 在线决策转换器 (ODT) 建立在 DT 架构之上,并包含由于随机策略而产生的变化。 我们通过输出端的两个独立的全连接层来预测策略均值和对数方差。 算法 1 总结了 ODT 中的整体微调管道,其中详细的内部训练步骤在算法 2 中进行了描述。我们在下面概述了这些算法的主要组成部分,并在附录 B 中讨论了其他设计选择和超参数。
  • 轨迹级回放缓冲器。我们使用重放缓冲区来记录我们过去的经历并定期更新。对于大多数现有的 RL 算法,重放缓冲区由转换组成。在 rollout 中的每一步在线交互之后,策略或 Q 函数都会通过梯度步骤进行更新,并执行策略以收集新的转换以添加到重放缓冲区中。然而,对于 ODT,我们的回放缓冲区包含轨迹而不是转换。离线预训练后,我们通过离线数据集中回报率最高的轨迹初始化回放缓冲区。每次我们与环境交互时,我们都会使用当前策略完全推出一个情节,然后以先进先出的方式使用收集到的轨迹刷新重播缓冲区。之后,我们再次更新策略并推出,如算法 1 所示。与 Haarnoja 等人 (2018a) 类似,我们还观察到使用平均动作评估策略通常会带来更高的回报,但使用采样更有好处在线探索的行动,因为它会产生更多样化的数据。
  • 事后回报重新贴标签。 Hindsight experience replay (HER) 是一种在奖励稀疏的环境中提高目标条件代理的样本效率的方法(Andrychowicz 等人,2017 年;Rauber 等人,2017 年;Ghosh 等人,2019 年) . 这里的关键思想是将智能体的轨迹重新标记为已实现的目标,而不是预期目标。 对于 ODT,我们正在学习以初始 RTG 为条件的策略。 然而,在政策推出期间获得的回报和诱导的 RTG 可能与预期的 RTG 不同。 受 HER 的启发,我们用实现的回报为展开的轨迹 τττ 重新标记 RTG 代币,这样在最后一个时间步 g∣τ∣g_{|\tau|}gτ 的 RTG 代币恰好是代理 r∣τ∣r_{|τ|}rτ 获得的奖励,参见算法 2 的第 6 行。 这种返回重新标记策略适用于奖励稀疏和密集的环境。
  • RTG 调节。 ODT 需要一个超参数,初始 RTG gonlineg_{online}gonline,用于收集额外的在线数据(参见算法 1 的第 4 行)。 此前,Chen 等人 (2021) 表明,离线 DT 的实际评估回报与经验上的初始 RTG 具有很强的相关性,并且通常可以推断出超过离线数据集中观察到的最大回报的 RTG 值。 对于 ODT,我们发现最好将此超参数设置为专家回报的一个小的、固定的比例(在我们的实验中设置为 2)。 我们还试验了更大的值以及随时间变化的课程(例如,离线和在线数据集中最佳评估回报的分位数),但我们发现这些相对于固定的、缩放的 RTG 而言是次优的。
  • 抽样策略。 与 DT 类似,算法 2 使用两步采样过程来确保重放缓冲区 Treplay 中长度为 K 的子轨迹被均匀采样。 我们首先以与其长度成正比的概率采样单个轨迹,然后统一采样长度为 K 的子轨迹。对于具有非负密集奖励的环境,我们的采样策略类似于重要性采样。 在这些情况下,轨迹的长度与其返回高度相关,正如我们在附录 F 中进一步强调的那样。

动态训练

  • 我们评论了一些关于 ODT 训练动态及其影响的经验观察。 我们首先展示一个示例运行,其中 ODT 的在线返回饱和,表明训练已经收敛。 我们将自己限制在算法 1 收敛的情况下,讨论 ODT 的训练动态。 这种收敛假设使我们能够分析学习目标在训练过程中的含义,以及初始 RTG 令牌对 ODT 策略的行为变化。 我们强调算法 1 的收敛保证是一个悬而未决的问题,超出了本文的范围,我们的主张将主要通过实验来指导。

http://www.ppmy.cn/news/156.html

相关文章

如何翻译英文音频?看完你就学会了

在平时的工作中&#xff0c;相信大家应该都会遇到一些不太熟悉的英文或者其它外文的语言&#xff0c;这给我们的生活带来了诸多烦恼&#xff0c;那遇到这种情况&#xff0c;我们应该怎么办呢&#xff1f;其实很简单&#xff0c;我们可以利用一些软件来将这些语言转换成中文&…

Java本地搭建实战毕设项目sprignboot电商书城管理系统源码

大家好啊&#xff0c;我是测评君&#xff0c;欢迎来到web测评。 本期给大家带来一套Java开发的sprignboot电商书城管理系统源码&#xff0c;包含前端界面、后台管理界面。适合拿来做毕业设计的同学。可以下载来研究学习一下。本期就把这套系统分享给大家。 技术架构 技术框架&…

C语言 深度探究C语言中的多字节字符

多字节字符 本章介绍 C 语言如何处理非英语字符。 Unicode 简介 C 语言诞生时&#xff0c;只考虑了英语字符&#xff0c;使用7位的 ASCII 码表示所有字符。ASCII 码的范围是0到127&#xff0c;也就是最多只能表示100多个字符&#xff0c;用一个字节就可以表示&#xff0c;所…

卡尔曼滤波与融合算法

不要被复杂公式吓到&#xff0c;按下面的步骤一步一步来&#xff0c;每个概念都学清楚&#xff0c;卡尔曼并不难理解学习卡尔曼&#xff0c;需要先了解几个基础知识 测不准定律&#xff1a;比如说我们要测量一个电压&#xff0c;需要借助传感器&#xff0c;但是传感器无法给出真…

年薪50W的数字前端设计工程师是做什么的?

近两年&#xff0c;芯片行业大火&#xff0c;行业的发展受到了很大的政策支持&#xff0c;芯片行业不仅发展前景好&#xff0c;薪资待遇也很高&#xff0c;所以不少人纷纷转行IC&#xff0c;那么转行IC岗位该如何选择呢&#xff1f;下面IC修真院就重点为大家来介绍一下数字前端…

「PAT乙级真题解析」Basic Level 1098 岩洞施工 (问题分析+完整步骤+伪代码描述+提交通过代码)

乙级的题目训练主要用来熟悉编程语言的语法和形成良好的编码习惯和编码规范。从小白开始逐步掌握用编程解决问题。 PAT (Basic Level) Practice 1098 岩洞施工 问题分析 题设给定了岩洞中每一个位置的顶部高度和底部高度, 要求判断是否能够将一个单位的长管道水平送入岩洞中。…

计算机毕业设计springboot紧急自救知识教学与交流平台9c75u源码+系统+程序+lw文档+部署

计算机毕业设计springboot紧急自救知识教学与交流平台9c75u源码系统程序lw文档部署 计算机毕业设计springboot紧急自救知识教学与交流平台9c75u源码系统程序lw文档部署本源码技术栈&#xff1a; 项目架构&#xff1a;B/S架构 开发语言&#xff1a;Java语言 开发软件&#xf…

springboot+jsp学生心理健康测评网

基于JSP技术设计并实现了学生心理健康网。该系统基于B/S即所谓浏览器/服务器模式&#xff0c;应用SSM框架&#xff0c;选择MySQL作为后台数据库。系统主要包括个人中心、用户管理、知识分类管理、知识信息管理、心理测试管理、交流论坛、试卷管理、系统管理、考试管理等功能模块…