目录
- 代码
- 代码解释
- 1. 初始化方法 `__init__`
- 2. 前向传播方法 `forward`
- 3. 总结
- 4. 使用场景
- 可视化
代码
class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def forward(self, x):return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
代码解释
这段代码定义了一个自定义的PyTorch模块 RMSNorm
,用于实现Root Mean Square Normalization (RMSNorm)。RMSNorm是一种归一化技术,类似于Layer Normalization,但它只对输入进行缩放,而不进行平移(即没有偏置项)。下面是代码的详细解释:
1. 初始化方法 __init__
def __init__(self, dim: int, eps: float):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))
dim: int
: 输入特征的维度。eps: float
: 一个小常数,用于数值稳定性,避免除以零的情况。self.weight
: 一个可学习的参数,形状为(dim,)
,初始化为全1的张量。这个参数用于对归一化后的输入进行缩放。
2. 前向传播方法 forward
def forward(self, x):return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
x
: 输入张量,形状通常为(batch_size, ..., dim)
。x.pow(2)
: 对输入x
的每个元素求平方。x.pow(2).mean(-1, keepdim=True)
: 沿着最后一个维度(即特征维度dim
)计算平方的均值,并保持维度不变。结果形状为(batch_size, ..., 1)
。torch.rsqrt(...)
: 计算均方根的倒数(即1除以平方根),用于归一化。x.float() * torch.rsqrt(...)
: 将输入x
转换为浮点数后,乘以均方根的倒数,得到归一化后的结果。.type_as(x)
: 将结果转换回与输入x
相同的数据类型。self.weight * (...)
: 最后,将归一化后的结果乘以可学习的权重self.weight
,进行缩放。
3. 总结
- RMSNorm 通过对输入进行归一化,使得每个特征的均方根值为1,然后通过可学习的权重进行缩放。
- 与LayerNorm不同,RMSNorm没有偏置项,只进行缩放操作。
eps
用于防止除以零的情况,增加数值稳定性。
4. 使用场景
RMSNorm通常用于深度学习模型中,特别是在Transformer架构中,作为LayerNorm的替代方案。它可以加速训练并提高模型的稳定性。
可视化
dim = 64
eps = 1e-5
m = RMSNorm(dim, eps)
x = torch.randn(32, 10, dim) # 示例输入 (batch_size, seq_len, dim)f = "rms_norm.onnx" # 导出的 ONNX 文件名
torch.onnx.export(m, x, f) # 模型 # 示例输入
在 https://netron.app/
上打开 rms_norm.onnx