Llama改进之——均方根层归一化RMSNorm

server/2024/9/24 0:29:28/

引言

在学习完GPT2之后,从本文开始进入Llama模型系列。

本文介绍Llama模型的改进之RMSNorm(均方根层归一化)。它是由Root Mean Square Layer Normalization论文提出来的,可以参阅其论文笔记1

LayerNorm

层归一化(LayerNorm)对Transformer等模型来说非常重要,它可以帮助稳定训练并提升模型收敛性。LayerNorm针对一个样本所有特征计算均值和方差,然后使用这些来对样本进行归一化:
μ = 1 H ∑ i = 1 H x i , σ = 1 H ∑ i = 1 H ( x i − μ ) 2 , N ( x ) = x − μ σ , h = g ⊙ N ( x ) + b (1) \mu = \frac{1}{H}\sum_{i=1}^H x_i,\quad \sigma = \sqrt{\frac{1}{H}\sum_{i=1}^H (x_i - \mu)^2}, \quad N(\pmb x) = \frac{\pmb x-\mu}{\sigma},\quad \pmb h = \pmb g \,\odot N(\pmb x) + \pmb b \tag 1 μ=H1i=1Hxi,σ=H1i=1H(xiμ)2 ,N(x)=σxμ,h=gN(x)+b(1)
这里 x = ( x 1 , x 2 , ⋯ , x H ) \pmb x = (x_1,x_2,\cdots, x_H) x=(x1,x2,,xH)表示某个时间步LN层的输入向量表示,向量维度为 H H H h \pmb h h实LN层的输出; g , b \pmb g,\pmb b g,b实两个可学习的参数。

为什么层归一化有用?一些解释如下2

  1. 减少内部协变量偏移(Internal Covariate Shift): 内部协变量偏移是指在深度神经网络的训练过程中,每一层输入的分布会发生变化,导致网络的训练变得困难。层归一化通过对每一层的输入进行归一化处理,可以减少内部协变量偏移,使得每一层的输入分布更加稳定。
  2. 稳定化梯度: 层归一化有助于保持每一层输出的均值和方差稳定,从而使得梯度的传播更加稳定。这有助于减少梯度消失或梯度爆炸的问题,提高梯度在网络中的流动性,加快训练速度。
  3. 更好的参数初始化和学习率调整: 通过层归一化,每一层的输入分布被归一化到均值为0、方差为1的标准正态分布,这有助于更好地初始化网络参数和调整学习率。参数初始化与学习率调整的稳定性对模型的训练效果至关重要。
  4. 增强模型的泛化能力: 层归一化可以减少网络对训练数据分布的依赖,降低了过拟合的风险,从而提高模型的泛化能力。稳定的输入分布有助于模型更好地适应不同数据集和任务。

RMSNorm

虽然LayerNorm很好,但是它每次需要计算均值和方差。RMSNorm的思想就是移除(1)式中 μ \mu μ的计算部分1
x ˉ i = x i RMS ( x ) g i RMS ( x ) = 1 H ∑ i = 1 H x i 2 (2) \bar x_i = \frac{x_i }{ \text{RMS}(\pmb x)} g_i \quad \text{RMS}(\pmb x) =\sqrt{\frac{1}{H} \sum_{i=1}^H x_i^2} \tag 2 xˉi=RMS(x)xigiRMS(x)=H1i=1Hxi2 (2)

同时在实现也可以移除平移偏置 b \pmb b b

单看(2)式的话,相当于仅使用 x \pmb x x的均方根来对输入进行归一化,它简化了层归一化的计算,变得更加高效,同时还有可能带来性能上的提升。

实现

RMSNorm的实现很简单:

import torch
import torch.nn as nn
from torch import Tensorclass RMSNorm(nn.Module):def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(hidden_size))def _norm(self, hidden_states: Tensor) -> Tensor:variance = hidden_states.pow(2).mean(-1, keepdim=True)return hidden_states * torch.rsqrt(variance + self.eps)def forward(self, hidden_states: Tensor) -> Tensor:return self.weight * self._norm(hidden_states.float()).type_as(hidden_states)

torch.rsqrttorch.sqrt的倒数;eps是一个很小的数,防止除零;hidden_states.float()确保了标准差计算的精确度和稳定性,然后在forward方法中,通过.type_as(hidden_states)将结果转换回原来的数据类型,以保持与输入张量相同的数据类型,使得归一化处理后的结果与输入数据类型一致。

下面通过一个简单的网络来测试一下:

import torch
import torch.nn as nn
from torch import Tensorclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.linear = nn.Linear(in_features=10, out_features=5)self.rmsnorm = RMSNorm(hidden_size=5)def forward(self, x):x = self.linear(x)x = self.rmsnorm(x)return xnet = SimpleNet()input_data = torch.randn(2, 10)  # 2个样本,每个样本包含10个特征output = net(input_data)print("Input Shape:", input_data.shape)
print("Output Shape:", output.shape)
Input Shape: torch.Size([2, 10])
Output Shape: torch.Size([2, 5])

参考


  1. [论文笔记]Root Mean Square Layer Normalization ↩︎ ↩︎

  2. 批归一化和层归一化 ↩︎


http://www.ppmy.cn/server/11956.html

相关文章

【软考经验分享】软考-中级-嵌入式备考

这里写目录标题 教辅用书嵌入式系统设计师考试大纲嵌入式系统设计师教程嵌入式系统设计师5天修炼嵌入式系统设计师考前冲刺100题 刷题软件希赛网软考真题 视频教程希赛网王道-计组计网 教辅用书 嵌入式系统设计师考试大纲 50页左右,内容为罗列一些考点&#xff0c…

通过“命令提示符(cmd)”注销后台帐号用户

通过“命令提示符(cmd)”注销后台帐号用户 1 2 3 4 分步阅读 电脑上面后台使用的用户较多(包括远程连接),电脑的运行负荷将会增加,电脑响应缓慢,甚至会影响到正常的使用&#xff0c…

mysql基础19——日志

日志 mysql的日志种类非常多 通用查询日志 慢查询日志 错误日志 与时间有关联 二进制日志 中继日志 与主从服务器的同步有关 重做日志 回滚日志 与数据丢失有关 通用查询日志 记录了所有用户的连接开始时间和截至时间 以及给mysql服务器发送的所有指令 当数据异常时&…

Redhawk:ATE如何产生top level sta file

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 相关文章链接 redhawk: create STA file 在“redhawk: create STA file”一文中介绍了ate的用法,可以应对block level的设计,但当需要做top level分析时&

C语言学习/复习29--内存操作函数memcpy/memmove/memset/memcmp

一、内存操作函数 1.memcpy()函数 注意事项1:复制的数目以字节为单位 注意事项2:一定要保证有足够空间复制 模拟实现1 拷贝字符案例:由于拷贝时函数本事就以字节为单位拷贝所以该例子也可用于其他类型数据的拷贝。 模拟实现2 将自身的…

WebSocket的原理、作用、常见注解和生命周期的简单介绍,附带SpringBoot示例

文章目录 WebSocket是什么WebSocket的原理WebSocket的作用全双工和半双工客户端【浏览器】API服务端 【Java】APIWebSocket的生命周期WebSocket的常见注解SpringBoot简单代码示例 WebSocket是什么 WebSocket是一种 通信协议 ,它在 客户端和服务器之间建立了一个双向…

AGI的智力有可能在两年内超过人类水平

特斯拉CEO埃隆马斯克近日与挪威银行投资管理基金CEO坦根的访谈中表示,AGI的智力将在两年内可能超过人类智力,在未来五年内,AI的能力很可能超过所有人类。 马斯克透漏,去年人工智能发展过程中的主要制约因素是缺少高性能芯片&#…

基于C++ DNN部署Yolov8出现的问题记录

代码问题 报错行:net.forward(outputs, net.getUnconnectedOutLayersNames()) 错误展示 错误代码:Exception message: OpenCV(4.8.1) C:\GHA-OCV-2\_work\ci-gha-workflow\ci-gha-workflow\opencv\modules\dnn\src\layers\reshape_layer.cpp:109: err…