通过模拟对CLIP进行解释:如何通过梯度提升正样本的相似度?
具体CLIP可以参考笔者的另外的博客: CLIP 的核心训练代码与对比损失的解释:中英双语 和 对比损失(Contrastive Loss)与大模型:Contrastive Loss and Large Models (中英双语)
交叉熵损失在 CLIP 中的工作原理
-
相似性矩阵(Logits):
logits_per_image
是一个 ( batch_size × batch_size \text{batch\_size} \times \text{batch\_size} batch_size×batch_size ) 的矩阵。- 例如,假设 batch size 为 4:
logits_per_image = [ 1.2 0.3 − 0.8 0.5 0.4 1.5 0.1 − 0.2 0.0 − 0.3 2.0 0.6 − 0.5 0.7 0.8 1.3 ] \text{logits\_per\_image} = \begin{bmatrix} 1.2 & 0.3 & -0.8 & 0.5 \\ 0.4 & 1.5 & 0.1 & -0.2 \\ 0.0 & -0.3 & 2.0 & 0.6 \\ -0.5 & 0.7 & 0.8 & 1.3 \end{bmatrix} logits_per_image= 1.20.40.0−0.50.31.5−0.30.7−0.80.12.00.80.5−0.20.61.3 - 对角线上的值是正样本的相似度,其他值是负样本的相似度。
-
交叉熵损失的计算:
- 对于每一行(例如第 ( i i i ) 行),交叉熵损失希望第 ( i i i ) 列的值(正样本)最大化,而其他列(负样本)最小化。
- 计算公式如下:
CrossEntropyLoss = − 1 N ∑ i = 1 N log exp ( logits [ i , i ] ) ∑ j = 1 N exp ( logits [ i , j ] ) \text{CrossEntropyLoss} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(\text{logits}[i, i])}{\sum_{j=1}^N \exp(\text{logits}[i, j])} CrossEntropyLoss=−N1i=1∑Nlog∑j=1Nexp(logits[i,j])exp(logits[i,i]) - 该公式中的分母(归一化项)将正样本和负样本的相似度联合建模,形成一种竞争关系。
-
正负样本的距离调节:
- 拉近正样本距离:通过最大化正样本(对角线值)在 softmax 分布中的概率。
- 拉远负样本距离:通过对其他负样本的值施加抑制,使它们的 softmax 概率接近 0。
在这个例子中,我们通过 梯度下降法 来优化 logits_per_image
中的正样本分数(即对角线上的值,例如 2.0)。以下是详细的步骤,包括梯度计算、权重更新以及下一轮的损失变化。
假设前提
- 初始相似度矩阵(
logits_per_image
):
logits_per_image = [ 2.0 0.5 − 1.0 0.3 1.8 0.2 − 0.5 0.4 1.5 ] \text{logits\_per\_image} = \begin{bmatrix} 2.0 & 0.5 & -1.0 \\ 0.3 & 1.8 & 0.2 \\ -0.5 & 0.4 & 1.5 \end{bmatrix} logits_per_image= 2.00.3−0.50.51.80.4−1.00.21.5 - 学习率:( η = 0.1 \eta = 0.1 η=0.1 )
- 批量大小:( batch_size = 3 \text{batch\_size} = 3 batch_size=3 )
- 目标:通过一轮梯度下降,提升正样本的相似度,同时降低负样本的相似度。
Step 1: 计算梯度
Softmax 概率计算
以第一行 ( logits [ 0 , : ] \text{logits}[0, :] logits[0,:] ) 为例:
logits [ 0 , : ] = [ 2.0 , 0.5 , − 1.0 ] \text{logits}[0, :] = [2.0, 0.5, -1.0] logits[0,:]=[2.0,0.5,−1.0]
对应的 softmax 概率为:
P [ i , j ] = exp ( logits [ i , j ] ) ∑ k exp ( logits [ i , k ] ) P[i, j] = \frac{\exp(\text{logits}[i, j])}{\sum_{k} \exp(\text{logits}[i, k])} P[i,j]=∑kexp(logits[i,k])exp(logits[i,j])
计算分母:
denominator = exp ( 2.0 ) + exp ( 0.5 ) + exp ( − 1.0 ) ≈ 7.389 + 1.649 + 0.368 = 9.406 \text{denominator} = \exp(2.0) + \exp(0.5) + \exp(-1.0) \approx 7.389 + 1.649 + 0.368 = 9.406 denominator=exp(2.0)+exp(0.5)+exp(−1.0)≈7.389+1.649+0.368=9.406
正样本 ( P ( positive ) P(\text{positive}) P(positive) ):
P ( positive ) = exp ( 2.0 ) denominator = 7.389 9.406 ≈ 0.785 P(\text{positive}) = \frac{\exp(2.0)}{\text{denominator}} = \frac{7.389}{9.406} \approx 0.785 P(positive)=denominatorexp(2.0)=9.4067.389≈0.785
负样本 ( P ( negative , j = 2 ) P(\text{negative}, j=2) P(negative,j=2) ):
P ( negative , j = 2 ) = exp ( 0.5 ) denominator = 1.649 9.406 ≈ 0.175 P(\text{negative}, j=2) = \frac{\exp(0.5)}{\text{denominator}} = \frac{1.649}{9.406} \approx 0.175 P(negative,j=2)=denominatorexp(0.5)=9.4061.649≈0.175
负样本 ( P ( negative , j = 3 ) P(\text{negative}, j=3) P(negative,j=3) ):
P ( negative , j = 3 ) = exp ( − 1.0 ) denominator = 0.368 9.406 ≈ 0.039 P(\text{negative}, j=3) = \frac{\exp(-1.0)}{\text{denominator}} = \frac{0.368}{9.406} \approx 0.039 P(negative,j=3)=denominatorexp(−1.0)=9.4060.368≈0.039
交叉熵损失对 logits 的梯度
交叉熵损失公式:
Loss i = − log ( P ( positive ) ) \text{Loss}_i = -\log(P(\text{positive})) Lossi=−log(P(positive))
对 ( logits [ 0 , : ] \text{logits}[0, :] logits[0,:] ) 的梯度计算:
∂ Loss i ∂ logits [ 0 , j ] = P [ i , j ] − δ i , j \frac{\partial \text{Loss}_i}{\partial \text{logits}[0, j]} = P[i, j] - \delta_{i, j} ∂logits[0,j]∂Lossi=P[i,j]−δi,j
其中 ( δ i , j \delta_{i, j} δi,j ) 是 Kronecker delta,表示只有正样本(即对角线)位置是 1,其余为 0。
对于第一行:
- 正样本 ( j = 0 j=0 j=0 ):
∂ Loss 0 ∂ logits [ 0 , 0 ] = P [ 0 , 0 ] − 1 = 0.785 − 1 = − 0.215 \frac{\partial \text{Loss}_0}{\partial \text{logits}[0, 0]} = P[0, 0] - 1 = 0.785 - 1 = -0.215 ∂logits[0,0]∂Loss0=P[0,0]−1=0.785−1=−0.215 - 负样本 ( j = 1 j=1 j=1 ):
∂ Loss 0 ∂ logits [ 0 , 1 ] = P [ 0 , 1 ] = 0.175 \frac{\partial \text{Loss}_0}{\partial \text{logits}[0, 1]} = P[0, 1] = 0.175 ∂logits[0,1]∂Loss0=P[0,1]=0.175 - 负样本 ( j = 2 j=2 j=2 ):
∂ Loss 0 ∂ logits [ 0 , 2 ] = P [ 0 , 2 ] = 0.039 \frac{\partial \text{Loss}_0}{\partial \text{logits}[0, 2]} = P[0, 2] = 0.039 ∂logits[0,2]∂Loss0=P[0,2]=0.039
Step 2: 更新 logits
使用梯度下降法更新:
logits [ i , j ] = logits [ i , j ] − η ⋅ ∂ Loss i ∂ logits [ i , j ] \text{logits}[i, j] = \text{logits}[i, j] - \eta \cdot \frac{\partial \text{Loss}_i}{\partial \text{logits}[i, j]} logits[i,j]=logits[i,j]−η⋅∂logits[i,j]∂Lossi
对于第一行:
- 正样本 ( j = 0 j=0 j=0 ):
logits [ 0 , 0 ] = 2.0 − 0.1 ⋅ ( − 0.215 ) = 2.0 + 0.0215 = 2.0215 \text{logits}[0, 0] = 2.0 - 0.1 \cdot (-0.215) = 2.0 + 0.0215 = 2.0215 logits[0,0]=2.0−0.1⋅(−0.215)=2.0+0.0215=2.0215 - 负样本 ( j = 1 j=1 j=1 ):
logits [ 0 , 1 ] = 0.5 − 0.1 ⋅ 0.175 = 0.5 − 0.0175 = 0.4825 \text{logits}[0, 1] = 0.5 - 0.1 \cdot 0.175 = 0.5 - 0.0175 = 0.4825 logits[0,1]=0.5−0.1⋅0.175=0.5−0.0175=0.4825 - 负样本 ( j = 2 j=2 j=2 ):
logits [ 0 , 2 ] = − 1.0 − 0.1 ⋅ 0.039 = − 1.0 − 0.0039 = − 1.0039 \text{logits}[0, 2] = -1.0 - 0.1 \cdot 0.039 = -1.0 - 0.0039 = -1.0039 logits[0,2]=−1.0−0.1⋅0.039=−1.0−0.0039=−1.0039
更新后的第一行 logits:
logits [ 0 , : ] = [ 2.0215 , 0.4825 , − 1.0039 ] \text{logits}[0, :] = [2.0215, 0.4825, -1.0039] logits[0,:]=[2.0215,0.4825,−1.0039]
Step 3: 下一轮的损失计算
使用更新后的 logits 重新计算 softmax 概率和损失:
logits [ 0 , : ] = [ 2.0215 , 0.4825 , − 1.0039 ] \text{logits}[0, :] = [2.0215, 0.4825, -1.0039] logits[0,:]=[2.0215,0.4825,−1.0039]
分母:
denominator = exp ( 2.0215 ) + exp ( 0.4825 ) + exp ( − 1.0039 ) ≈ 7.562 + 1.620 + 0.367 = 9.549 \text{denominator} = \exp(2.0215) + \exp(0.4825) + \exp(-1.0039) \approx 7.562 + 1.620 + 0.367 = 9.549 denominator=exp(2.0215)+exp(0.4825)+exp(−1.0039)≈7.562+1.620+0.367=9.549
正样本概率:
P ( positive ) = exp ( 2.0215 ) denominator = 7.562 9.549 ≈ 0.792 P(\text{positive}) = \frac{\exp(2.0215)}{\text{denominator}} = \frac{7.562}{9.549} \approx 0.792 P(positive)=denominatorexp(2.0215)=9.5497.562≈0.792
损失:
Loss = − log ( P ( positive ) ) = − log ( 0.792 ) ≈ 0.233 \text{Loss} = -\log(P(\text{positive})) = -\log(0.792) \approx 0.233 Loss=−log(P(positive))=−log(0.792)≈0.233
对比:
- 更新前损失:( 0.34 )
- 更新后损失:( 0.233 )
总结
通过一轮梯度下降:
- 正样本相似度提升(从 2.0 增加到 2.0215)。
- 负样本相似度下降(例如从 0.5 减少到 0.4825)。
- 损失函数减小,从 ( 0.34 ) 减少到 ( 0.233 ),说明模型朝着更优的方向优化。
这种优化方式有效地让正样本对更加接近,而负样本对更加远离,从而提升模型在对比学习任务中的表现。
后记
2024年12月13日21点56分于上海,在GPT4o大模型辅助下完成。