Kaiming Uniform 初始化:神经网络权重初始化的优雅解决方案
在深度学习的模型训练中,权重初始化的选择对网络的收敛速度和性能有着深远影响。传统的随机初始化(如高斯分布)在浅层网络中尚可接受,但随着网络深度增加,梯度消失或爆炸的问题变得愈发严重。为了解决这一问题,Kaiming He 等人在 2015 年提出了 Kaiming 初始化(也称为 He 初始化),其中 kaiming_uniform_
是 PyTorch 中实现的一种均匀分布形式。本文将深入探讨其原理、实现细节以及在 LoRA(Low-Rank Adaptation)中的应用。
一、背景与动机
在深度神经网络(如 CNN 或 Transformer)中,前向传播和反向传播的稳定性依赖于每一层的输入和输出的数值范围。如果权重初始化不当,可能导致:
- 梯度消失:每一层的激活值过小,梯度在反向传播中逐渐衰减至零。
- 梯度爆炸:激活值过大,梯度在反向传播中指数增长,导致训练不稳定。
Kaiming 初始化通过分析网络的方差传播,提出了一种基于层输入和输出维度的初始化方法,确保信号在深层网络中的稳定传递。它最初在论文 Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification(He et al., 2015)中提出,针对 ReLU 激活函数设计,后来被广泛应用于各类网络。
二、Kaiming Uniform 的数学原理
Kaiming 初始化假设网络使用 ReLU 或其变体(如 Leaky ReLU)作为激活函数,其核心目标是保持每一层的输出方差与输入方差一致。
1. 前向传播的方差分析
考虑一个线性层:
y = W x y = W x y=Wx
- ( x ∈ R n i n x \in \mathbb{R}^{n_{in}} x∈Rnin ):输入向量,( n i n n_{in} nin ) 是输入维度。
- ( W ∈ R n o u t × n i n W \in \mathbb{R}^{n_{out} \times n_{in}} W∈Rnout×nin ):权重矩阵,( n o u t n_{out} nout ) 是输出维度。
- ( y ∈ R n o u t y \in \mathbb{R}^{n_{out}} y∈Rnout ):输出向量。
假设:
- ( x x x ) 的每个元素是独立同分布(i.i.d.),方差为 ( Var ( x ) \text{Var}(x) Var(x) )。
- ( W W W ) 的每个元素也是 i.i.d.,初始方差为 ( Var ( W ) \text{Var}(W) Var(W) )。
则输出 ( y i y_i yi ) 的方差为:
Var ( y i ) = Var ( ∑ j = 1 n i n W i j x j ) = ∑ j = 1 n i n Var ( W i j x j ) \text{Var}(y_i) = \text{Var}\left( \sum_{j=1}^{n_{in}} W_{ij} x_j \right) = \sum_{j=1}^{n_{in}} \text{Var}(W_{ij} x_j) Var(yi)=Var(j=1∑ninWijxj)=j=1∑ninVar(Wijxj)
若 ( W i j W_{ij} Wij ) 和 ( x j x_j xj ) 独立:
Var ( y i ) = n i n ⋅ Var ( W i j ) ⋅ Var ( x j ) = n i n ⋅ Var ( W ) ⋅ Var ( x ) \text{Var}(y_i) = n_{in} \cdot \text{Var}(W_{ij}) \cdot \text{Var}(x_j) = n_{in} \cdot \text{Var}(W) \cdot \text{Var}(x) Var(yi)=nin⋅Var(Wij)⋅Var(xj)=nin⋅Var(W)⋅Var(x)
对于 ReLU 激活 ( f ( y ) = max ( 0 , y ) f(y) = \max(0, y) f(y)=max(0,y) ),输出方差减半(因为一半的激活被置零):
Var ( f ( y ) ) = 1 2 n i n ⋅ Var ( W ) ⋅ Var ( x ) \text{Var}(f(y)) = \frac{1}{2} n_{in} \cdot \text{Var}(W) \cdot \text{Var}(x) Var(f(y))=21nin⋅Var(W)⋅Var(x)
为了保持 ( Var ( f ( y ) ) = Var ( x ) \text{Var}(f(y)) = \text{Var}(x) Var(f(y))=Var(x) ):
1 2 n i n ⋅ Var ( W ) = 1 ⟹ Var ( W ) = 2 n i n \frac{1}{2} n_{in} \cdot \text{Var}(W) = 1 \implies \text{Var}(W) = \frac{2}{n_{in}} 21nin⋅Var(W)=1⟹Var(W)=nin2
2. 均匀分布的参数
Kaiming 初始化使用均匀分布 ( U ( − a , a ) U(-a, a) U(−a,a) ) 初始化权重,其中 ( a a a ) 是边界值。对于均匀分布:
Var ( W ) = ( a − ( − a ) ) 2 12 = ( 2 a ) 2 12 = 4 a 2 12 = a 2 3 \text{Var}(W) = \frac{(a - (-a))^2}{12} = \frac{(2a)^2}{12} = \frac{4a^2}{12} = \frac{a^2}{3} Var(W)=12(a−(−a))2=12(2a)2=124a2=3a2
令:
a 2 3 = 2 n i n \frac{a^2}{3} = \frac{2}{n_{in}} 3a2=nin2
解得:
a = 6 n i n a = \sqrt{\frac{6}{n_{in}}} a=nin6
因此,权重初始化为:
W ∼ U ( − 6 n i n , 6 n i n ) W \sim U\left(-\sqrt{\frac{6}{n_{in}}}, \sqrt{\frac{6}{n_{in}}}\right) W∼U(−nin6,nin6)
3. 参数 ( a a a ) 的调整
在 PyTorch 的 nn.init.kaiming_uniform_
中,参数 ( a a a ) 允许用户调整分布的宽度,默认值为 ( 5 \sqrt{5} 5 )(对应 Xavier 初始化的一种变体)。但在 Kaiming 原始设计中,( a = 0 a = 0 a=0 )(纯 ReLU)或基于激活函数的斜率调整。
三、PyTorch 中的实现
PyTorch 提供了 nn.init.kaiming_uniform_
函数,签名如下:
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='relu')
- tensor:要初始化的张量,通常是权重矩阵(形状为 ( [ n o u t , n i n ] [n_{out}, n_{in}] [nout,nin] ))。
- a:负斜率参数,对于 ReLU 通常为 0,对于 Leaky ReLU 为斜率值。
- mode:选择
fan_in
(输入维度,推荐用于前向稳定性)或fan_out
(输出维度,用于反向稳定性)。 - nonlinearity:激活函数类型(如
'relu'
或'leaky_relu'
)。
在 LoRA 的代码中(例如 Microsoft LoRA 实现, https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py):
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
- 这里 ( a = 5 a = \sqrt{5} a=5 ) 是 Xavier 初始化的遗留参数,但实际效果接近 Kaiming。
self.lora_A.weight
是 ( [ d , r ] [d, r] [d,r] ) 形状的矩阵,( d d d ) 是输入维度,( r r r ) 是秩。
四、在 LoRA 中的应用
LoRA(Low-Rank Adaptation)通过低秩矩阵 ( A A A ) 和 ( B B B ) 对预训练权重进行微调。在 reset_lora_parameters
函数中:
- ( A A A ) 使用
kaiming_uniform_
初始化:nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
- ( B B B ) 初始化为零:
nn.init.zeros_(self.lora_B[adapter_name].weight)
1. 为什么用 Kaiming 初始化 ( A A A )?
- 数值稳定性:( A ∈ R d × r A \in \mathbb{R}^{d \times r} A∈Rd×r ) 的初始值需要适度随机化,以确保训练早期 ( Δ W = A B \Delta W = A B ΔW=AB) 能够探索不同的低秩子空间。Kaiming 初始化根据输入维度 ( d d d ) 设置合理的方差,避免 ( A A A ) 的值过大或过小。
- 与 ( B ) 零初始化的配合:( B = 0 B = 0 B=0 ) 确保初始 ( Δ W = 0 \Delta W = 0 ΔW=0),而 ( A A A ) 的 Kaiming 初始化为后续更新提供了非零起点,避免训练陷入全零状态。
- 兼容性:LoRA 常用于 Transformer 等深层模型,Kaiming 初始化与这些架构的 ReLU 或 GELU 激活兼容。
2. ( a = 5 a = \sqrt{5} a=5 ) 的意义
- ( a = 5 a = \sqrt{5} a=5 ) 实际上是向后兼容 Xavier 初始化的一种选择,但与 Kaiming 的 ( a = 0 a = 0 a=0 )(纯 ReLU)略有不同。它使得分布稍宽,可能增加初始探索性,但仍接近 ( 6 n i n \sqrt{\frac{6}{n_{in}}} nin6 ) 的理论值。
五、优势与局限性
优势
- 稳定性:通过控制方差,Kaiming 初始化显著减少了深层网络中的梯度问题。
- 通用性:适用于大多数使用 ReLU 及其变体的网络。
- 简单性:只需输入维度即可计算,无需复杂调参。
局限性
- 激活函数依赖:主要针对 ReLU 设计,对于 Sigmoid 或 Tanh 等激活可能不适用。
- 固定假设:假设输入是 i.i.d.,在实际复杂数据中可能不完全成立。
- LoRA 场景的特殊性:( r r r ) 远小于 ( d d d ),低秩约束可能削弱初始化的理论效果。
六、扩展与改进方向
对于 LoRA 或其他场景,可以考虑以下改进:
- 动态调整 ( a a a ):根据 ( r r r ) 或任务复杂度自适应选择 ( a a a ),而不是固定为 ( 5 \sqrt{5} 5 )。
- 正交初始化:对于低秩矩阵 ( A A A ),尝试
torch.nn.init.orthogonal_
,可能提升多样性。 - 与 scaling 结合:LoRA 中的
scaling = lora_alpha / r
可以与初始化协同设计,避免重复调整幅度。
七、结语
nn.init.kaiming_uniform_
是深度学习中权重初始化的经典方法,通过均匀分布确保信号在网络中的稳定传播。在 LoRA 中,它为 ( A A A ) 提供了合理的初始值,与 ( B B B ) 的零初始化配合,兼顾了稳定性和学习能力。对于研究者而言,理解其背后的数学原理不仅有助于调参,还能启发新的初始化策略。欢迎在评论区分享你的使用经验或改进想法!
后记
2025年3月11日22点44分于上海,在Grok 3大模型辅助下完成。