InstructGPT的四阶段:预训练、有监督微调、奖励建模、强化学习涉及到的公式解读

server/2024/10/16 0:25:55/

1. 预训练

在这里插入图片描述

1. 语言建模目标函数(公式1):

L 1 ( U ) = ∑ i log ⁡ P ( u i ∣ u i − k , … , u i − 1 ; Θ ) L_1(\mathcal{U}) = \sum_{i} \log P(u_i \mid u_{i-k}, \dots, u_{i-1}; \Theta) L1(U)=ilogP(uiuik,,ui1;Θ)

  • 解释
    • U = { u 1 , u 2 , … , u n } \mathcal{U} = \{u_1, u_2, \dots, u_n\} U={u1,u2,,un}是输入的未标注语料(token序列)。
    • k k k是上下文窗口的大小,即预测当前词 u i u_i ui时,使用前 k k k个词( u i − k , … , u i − 1 u_{i-k}, \dots, u_{i-1} uik,,ui1)作为上下文。
    • P ( u i ∣ u i − k , … , u i − 1 ; Θ ) P(u_i \mid u_{i-k}, \dots, u_{i-1}; \Theta) P(uiuik,,ui1;Θ) 是模型根据前 k k k个词预测 u i u_i ui的条件概率,其中参数 Θ \Theta Θ是通过训练得到的神经网络参数。
    • 通过最大化对数似然,模型被训练以最小化预测和真实词之间的差距,这个过程通常通过随机梯度下降(SGD)进行。

2. Transformer解码器结构(公式2):

公式2描述了模型的架构,采用了多层的Transformer解码器。Transformer通过自注意力机制来捕捉上下文依赖关系,并对输入序列进行编码。

初始嵌入层:

h 0 = U W e + W p h_0 = U W_e + W_p h0=UWe+Wp

  • 解释
    • U = ( u − k , … , u − 1 ) U = (u_{-k}, \dots, u_{-1}) U=(uk,,u1) 是上下文窗口中输入序列的词向量。
    • W e W_e We是词嵌入矩阵,用于将输入的token转换为词向量。
    • W p W_p Wp是位置嵌入矩阵,提供每个token的位置编码,用于捕捉词序信息。
Transformer块:

h l = transformer_block ( h l − 1 ) ∀ i ∈ [ 1 , n ] h_l = \text{transformer\_block}(h_{l-1}) \quad \forall i \in [1, n] hl=transformer_block(hl1)i[1,n]

  • 解释
    • h l h_l hl表示第 l l l层的Transformer输出。
    • 每一层 h l h_l hl是通过前一层 h l − 1 h_{l-1} hl1经过Transformer块(自注意力和前馈网络)的处理得到的。
    • 共有 n n n层,每一层都通过类似的操作进行。
输出层:

P ( u ) = softmax ( h n W e T ) P(u) = \text{softmax}(h_n W_e^T) P(u)=softmax(hnWeT)

  • 解释
    • h n h_n hn是最后一层的输出,经过词嵌入矩阵的转置 W e T W_e^T WeT变换后,再通过Softmax函数计算每个词的概率分布。
    • 这个概率分布用于预测输出的目标词 u u u,Softmax确保输出的各个词的概率和为1。

总结:

  • 该模型采用无监督的方式进行预训练,利用大规模未标注语料数据,通过最大化词序列的条件概率来训练语言模型。
  • 预训练的模型架构基于Transformer解码器,通过多层自注意力机制和位置编码来有效捕捉上下文信息,并使用Softmax输出目标词的概率分布。

2. 有监督微调

在这里插入图片描述

1. 微调任务的目标函数(公式3):

P ( y ∣ x 1 , … , x m ) = softmax ( h l m W y ) P(y \mid x^1, \dots, x^m) = \text{softmax}(h_l^m W_y) P(yx1,,xm)=softmax(hlmWy)

  • 解释
    • x 1 , … , x m x^1, \dots, x^m x1,,xm 是输入的token序列。
    • h l m h_l^m hlm表示输入序列经过预训练模型(如Transformer)的最后一层输出的激活值(即特征表示)。
    • W y W_y Wy是用于预测目标标签 y y y的线性层的参数矩阵。
    • Softmax 函数将线性层的输出转换为每个类的概率分布,用于分类任务中的标签预测。

2. 最大化目标函数(公式4):

L 2 ( C ) = ∑ ( x , y ) log ⁡ P ( y ∣ x 1 , … , x m ) L_2(C) = \sum_{(x, y)} \log P(y \mid x^1, \dots, x^m) L2(C)=(x,y)logP(yx1,,xm)

  • 解释
    • 这是监督学习的目标函数,模型通过最大化预测标签 y y y 的对数概率来微调模型参数。
    • C C C是标注数据集,包含输入序列 x x x和相应的标签 y y y
    • 目标是最大化所有样本的对数似然,确保模型在监督任务中的准确性。

3. 辅助目标函数(公式5):

L 3 ( C ) = L 2 ( C ) + λ ∗ L 1 ( C ) L_3(C) = L_2(C) + \lambda \ast L_1(C) L3(C)=L2(C)+λL1(C)

  • 解释
    • 为了提高监督学习的效果,模型还结合了语言建模的辅助目标,即无监督的语言建模损失( L 1 L_1 L1 )和监督任务的损失( L 2 L_2 L2)相结合。
    • λ \lambda λ 是用于平衡两个目标的权重参数。
    • 这样做的好处是:可以通过语言模型的任务帮助监督任务更好地泛化,同时加快收敛速度。这种辅助目标在之前的研究中已证明可以有效提高性能。

总结:

  • 监督微调阶段,模型利用预训练好的参数,结合带标签的数据来优化预测性能。
  • 通过最大化预测标签 y y y的对数概率,模型适应特定任务。
  • 引入语言建模作为辅助任务,有助于提升模型的泛化能力和训练效率。

3. 奖励建模

在这里插入图片描述

这个段落介绍了 奖励模型(Reward Modeling, RM) 在 InstructGPT 模型中的训练方式,具体描述了模型如何从 监督微调模型(SFT) 继续优化以输出奖励分数(reward score),并通过 比较(comparison)来训练这个奖励模型。

核心内容解释:

  1. 奖励模型的基础

    • 奖励模型的训练是基于监督微调模型(SFT)进一步改进的。为了输出奖励,SFT 模型的最后一层(unembedding layer)被移除,留下的是可以根据给定的提示(prompt)和回应(response)输出一个标量的奖励值。
    • 为了节省计算资源,他们使用了一个 6B 参数的奖励模型(RM),因为较大规模的 175B 参数奖励模型虽然理论上可能更准确,但实际训练时表现不稳定,且不适合作为 RL 的值函数。
  2. Stiennon et al. (2020) 的方法

    • 数据集:奖励模型通过比较训练,数据集包含两个模型生成的输出(response),并根据这些生成的结果做出比较。
    • 损失函数:训练中使用了 交叉熵损失(cross-entropy loss),这些比较的标签是人类标注员根据两者的优劣给出的。交叉熵损失衡量了两个结果中哪一个应该被人类标注员优先选择,实际优化的是两者奖励分数的对数几率(log odds)。
  3. 加速比较收集过程

    • 在标注过程中,给标注员展示了 K = 4 到 9 个不同的生成回应,标注员需要对这些生成的结果进行排名。这会生成 ( K 2 ) \binom{K}{2} (2K)对比较,即每个标注任务中有 K 个回应时,将产生 K 中选取 2 个的组合数的比较对数。比如 K = 9 K = 9 K=9 时,会生成 ( 9 2 ) = 36 \binom{9}{2} = 36 (29)=36 对比较。
    • 为了避免过度拟合,研究者决定不将所有比较对一起训练(因为不同的比较对之间存在强相关性),而是仅从每个提示中抽取一对比较结果作为一个训练样本。这种方式更加 计算高效,因为对于每个完成的任务只需要一次前向传播(forward pass),而不是处理所有 ( K 2 ) \binom{K}{2} (2K) 的比较对。
  4. 奖励模型的损失函数

    奖励模型的损失函数定义如下:
    loss ( θ ) = − 1 ( K 2 ) E ( x , y w , y l ) ∼ D [ log ⁡ ( σ ( r θ ( x , y w ) − r θ ( x , y l ) ) ) ] \text{loss}(\theta) = -\frac{1}{\binom{K}{2}} \mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \left( \sigma \left( r_\theta(x, y_w) - r_\theta(x, y_l) \right) \right) \right] loss(θ)=(2K)1E(x,yw,yl)D[log(σ(rθ(x,yw)rθ(x,yl)))]

    • σ ( ⋅ ) \sigma(\cdot) σ()sigmoid 函数,用于将差值映射到 [0, 1] 区间,用来表示某一结果被标注员认为更优的概率。
    • r θ ( x , y ) r_\theta(x, y) rθ(x,y) 是奖励模型对于给定提示 x x x和生成的回应 y y y所输出的奖励分数。
    • y w y_w yw y l y_l yl分别是优胜和劣胜的生成结果。
    • D D D是人类标注员比较的结果数据集。
    • 1 ( K 2 ) \frac{1}{\binom{K}{2}} (2K)1 是对所有比较的标准化处理。

解释

  • 损失函数的目标是最大化优胜结果 y w y_w yw 比劣胜结果 y l y_l yl更受偏好的概率(通过两者奖励分数差异的 sigmoid 值来实现)。这个函数通过最小化损失来优化奖励模型,使得奖励模型能够更准确地给出与人类标注员偏好一致的分数。
  1. 防止过拟合的解释(脚注 5)
    • 如果每一个比较对都被视为一个单独的数据点,那么每个生成的回应可能在训练中会得到 K − 1 K-1 K1 次更新,从而导致模型过拟合。而研究人员发现,模型过度训练甚至只需一个 epoch 就会过拟合。为了解决这个问题,他们只对每个提示下的一对回应进行一次前向传播训练,从而避免过拟合。

总结:

  • 奖励模型的训练基于人类反馈,通过比较两个模型生成的回应来进行优化。该训练过程使用了 交叉熵损失函数,优化目标是让奖励模型尽可能地预测出哪个回应更符合人类标注员的偏好。
  • 通过只选取部分比较对进行训练(而不是所有组合对),减少了计算开销,并有效避免了模型过拟合。

4. 强化学习(PPO)

在这里插入图片描述

1. 强化学习(Reinforcement Learning, RL)

在这部分,模型使用了强化学习(RL)进行微调,采用了 PPO(Proximal Policy Optimization) 算法来优化策略。PPO 是一种策略梯度算法,常用于强化学习任务中,通过限制策略更新的步长来提高训练的稳定性。

2. 环境设置

强化学习的环境被设置为一个 多臂老虎机问题(bandit environment),该环境会随机给定提示(prompt),模型需要生成相应的回应。生成的回应会通过奖励模型(reward model)来打分并结束当前回合。

为了防止模型过度优化奖励模型,训练过程中在每个 token 的输出时,加入了 KL 惩罚项,这项惩罚的来源是监督微调模型(SFT, Supervised Fine-Tuned Model)。

3. PPO-ptx 模型

为了提高模型的泛化能力,研究者还尝试将 预训练梯度(pretraining gradients) 与 PPO 的梯度混合,构建了所谓的 PPO-ptx 模型。这种方法可以解决在某些公共 NLP 数据集上性能回退的问题。

他们使用了以下的 目标函数(Objective Function)

objective ( ϕ ) = E ( x , y ) ∼ D π ϕ R L [ r θ ( x , y ) − β log ⁡ ( π ϕ R L ( y ∣ x ) π S F T ( y ∣ x ) ) ] + γ E x ∼ D pretrain [ log ⁡ ( π ϕ R L ( x ) ) ] \text{objective}(\phi) = \mathbb{E}_{(x,y) \sim D_{\pi^{RL}_{\phi}}} \left[ r_\theta(x, y) - \beta \log \left( \frac{\pi^{RL}_{\phi}(y \mid x)}{\pi^{SFT}(y \mid x)} \right) \right] + \gamma \mathbb{E}_{x \sim D_{\text{pretrain}}} \left[ \log(\pi^{RL}_{\phi}(x)) \right] objective(ϕ)=E(x,y)DπϕRL[rθ(x,y)βlog(πSFT(yx)πϕRL(yx))]+γExDpretrain[log(πϕRL(x))]

4. 公式解读

第一部分:PPO 的核心目标

E ( x , y ) ∼ D π ϕ R L [ r θ ( x , y ) − β log ⁡ ( π ϕ R L ( y ∣ x ) π S F T ( y ∣ x ) ) ] \mathbb{E}_{(x,y) \sim D_{\pi^{RL}_{\phi}}} \left[ r_\theta(x, y) - \beta \log \left( \frac{\pi^{RL}_{\phi}(y \mid x)}{\pi^{SFT}(y \mid x)} \right) \right] E(x,y)DπϕRL[rθ(x,y)βlog(πSFT(yx)πϕRL(yx))]

  • E ( x , y ) ∼ D π ϕ R L \mathbb{E}_{(x,y) \sim D_{\pi^{RL}_{\phi}}} E(x,y)DπϕRL表示在 RL 策略 π ϕ R L \pi^{RL}_{\phi} πϕRL 生成的分布 D π ϕ R L D_{\pi^{RL}_{\phi}} DπϕRL上的期望。
  • r θ ( x , y ) r_\theta(x, y) rθ(x,y) 是奖励模型 r θ r_\theta rθ对生成的响应 y y y的奖励。
  • β log ⁡ ( π ϕ R L ( y ∣ x ) π S F T ( y ∣ x ) ) \beta \log \left( \frac{\pi^{RL}_{\phi}(y \mid x)}{\pi^{SFT}(y \mid x)} \right) βlog(πSFT(yx)πϕRL(yx))是 KL 散度惩罚项,惩罚 RL 策略 π ϕ R L \pi^{RL}_{\phi} πϕRL偏离 SFT 模型 π S F T \pi^{SFT} πSFT的程度,其中 β \beta β控制 KL 惩罚的权重。

解释
这部分的目标是最大化模型的奖励,同时通过 KL 惩罚防止策略 π ϕ R L \pi^{RL}_{\phi} πϕRL与监督微调模型 π S F T \pi^{SFT} πSFT偏离过远。惩罚项确保模型在强化学习时不走偏,保持与原本训练目标的相似性。

第二部分:预训练损失

γ E x ∼ D pretrain [ log ⁡ ( π ϕ R L ( x ) ) ] \gamma \mathbb{E}_{x \sim D_{\text{pretrain}}} \left[ \log(\pi^{RL}_{\phi}(x)) \right] γExDpretrain[log(πϕRL(x))]

  • E x ∼ D pretrain \mathbb{E}_{x \sim D_{\text{pretrain}}} ExDpretrain 表示在预训练数据分布 D pretrain D_{\text{pretrain}} Dpretrain上的期望。
  • log ⁡ ( π ϕ R L ( x ) ) \log(\pi^{RL}_{\phi}(x)) log(πϕRL(x)) 是 RL 策略 π ϕ R L \pi^{RL}_{\phi} πϕRL生成的结果的对数概率。
  • γ \gamma γ是预训练损失项的权重,控制预训练数据与强化学习的结合程度。

解释
这一部分引入了预训练的损失,使得模型能够保持在大规模预训练数据上的表现,防止模型在强化学习过程中完全依赖于奖励模型而失去通用能力。通过设置 γ \gamma γ,我们可以平衡预训练损失和强化学习损失的影响。

5. 策略和符号说明

  • π ϕ R L \pi^{RL}_{\phi} πϕRL 是在强化学习中学习到的策略,参数为 ϕ \phi ϕ
  • π S F T \pi^{SFT} πSFT 是通过监督学习微调得到的策略,它代表了模型在强化学习之前的性能。
  • β \beta β是 KL 惩罚的权重系数,控制 RL 策略和 SFT 策略的偏离程度。
  • γ \gamma γ 是预训练损失的权重系数,控制预训练梯度在 PPO 优化中的作用。

总结:

在这篇文章中,InstructGPT 使用了强化学习中的 PPO(Proximal Policy Optimization) 进行策略优化,同时通过引入 KL 散度惩罚项 来确保 RL 策略与 SFT 策略不过度偏离。此外,预训练损失 通过一个额外的项加入到了目标函数中,以解决在某些 NLP 任务上的性能回退问题。


http://www.ppmy.cn/server/132447.html

相关文章

ThingsBoard规则链节点:JSON Path节点详解

引言 JSON Path节点简介 用法 含义 应用场景 实际项目运用示例 智能农业监控系统 工业自动化生产线 车联网平台 结论 引言 ThingsBoard是一个功能强大的物联网平台,它提供了设备管理、数据收集与处理以及实时监控等核心功能。其规则引擎允许用户通过创建复…

Redis的应用以及Redis工具类的封装

在前后端分离的项目中,通过session和cookie的通信一般就失去效益了,即使这么做了也会产生著名的漏洞问题CSRF(Cross-site request forgery), 是一种挟制用户在当前已登录的Web应用程序上执行非本意的操作的攻击方法。因…

uni-app 如何全局设置,获取app.vue里面的值

在globalData里设置一个值 通过下面方法修改 this.$options.globalData.$versonStatus status 在页面中通过getApp()获取 getApp().globalData.$versonStatus

Android中的View绘制流程

Android中的View绘制流程是一个复杂而精细的过程,它确保了应用程序中的用户界面能够准确、高效地呈现在用户眼前。以下将详细阐述Android View的绘制流程,包括测量(Measure)、布局(Layout)和绘制&#xff0…

如何设置 GitLab 密码长度?

GitLab 是一个全球知名的一体化 DevOps 平台,很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab 是 GitLab 在中国的发行版,专门为中国程序员服务。可以一键式部署极狐GitLab。 学习极狐GitLab 的相关资料: 极狐GitLab 60天专业…

【mysql】数据库分区的使用

【mysql】数据库分区的使用 【一】分区的基本概念【1】物理存储与逻辑分割【2】查询性能提升【3】数据管理与维护【4】扩展性与并行处理 【二】分区的原理和类型【1】InnoDB逻辑存储结构【2】分区的原理【2】分区类型 【三】分区的优势和使用场景【四】如何实施分区【五】分区表…

python pip安装requirements.txt依赖与国内镜像

python pip安装requirements.txt依赖与国内镜像 如果网络通畅,直接pip安装依赖: pip install -r requirements.txt 如果需要国内的镜像,可以考虑使用阿里的,在后面加上: -i http://mirrors.aliyun.com/pypi/simple --…

三层b+树估算存储多少行数据

文章目录 B树结构图示估算方法(这里要以聚簇索引来看) B树结构图示 估算方法(这里要以聚簇索引来看) 非叶子节点数* 每个叶子结点记录总数 假设mysql 数据页,16kb,刚好对应B树的一个节点 每个叶子结点记录数, 叶子结点存储的是对应的原始数据…