通过模拟对CLIP进行解释:如何通过梯度提升正样本的相似度?

server/2024/12/14 10:42:29/

通过模拟对CLIP进行解释:如何通过梯度提升正样本的相似度?

具体CLIP可以参考笔者的另外的博客: CLIP 的核心训练代码与对比损失的解释:中英双语 和 对比损失(Contrastive Loss)与大模型:Contrastive Loss and Large Models (中英双语)

交叉熵损失在 CLIP 中的工作原理
  1. 相似性矩阵(Logits)

    • logits_per_image 是一个 ( batch_size × batch_size \text{batch\_size} \times \text{batch\_size} batch_size×batch_size ) 的矩阵。
    • 例如,假设 batch size 为 4:
      logits_per_image = [ 1.2 0.3 − 0.8 0.5 0.4 1.5 0.1 − 0.2 0.0 − 0.3 2.0 0.6 − 0.5 0.7 0.8 1.3 ] \text{logits\_per\_image} = \begin{bmatrix} 1.2 & 0.3 & -0.8 & 0.5 \\ 0.4 & 1.5 & 0.1 & -0.2 \\ 0.0 & -0.3 & 2.0 & 0.6 \\ -0.5 & 0.7 & 0.8 & 1.3 \end{bmatrix} logits_per_image= 1.20.40.00.50.31.50.30.70.80.12.00.80.50.20.61.3
    • 对角线上的值是正样本的相似度,其他值是负样本的相似度。
  2. 交叉熵损失的计算

    • 对于每一行(例如第 ( i i i ) 行),交叉熵损失希望第 ( i i i ) 列的值(正样本)最大化,而其他列(负样本)最小化。
    • 计算公式如下:
      CrossEntropyLoss = − 1 N ∑ i = 1 N log ⁡ exp ⁡ ( logits [ i , i ] ) ∑ j = 1 N exp ⁡ ( logits [ i , j ] ) \text{CrossEntropyLoss} = -\frac{1}{N} \sum_{i=1}^N \log \frac{\exp(\text{logits}[i, i])}{\sum_{j=1}^N \exp(\text{logits}[i, j])} CrossEntropyLoss=N1i=1Nlogj=1Nexp(logits[i,j])exp(logits[i,i])
    • 该公式中的分母(归一化项)将正样本和负样本的相似度联合建模,形成一种竞争关系。
  3. 正负样本的距离调节

    • 拉近正样本距离:通过最大化正样本(对角线值)在 softmax 分布中的概率。
    • 拉远负样本距离:通过对其他负样本的值施加抑制,使它们的 softmax 概率接近 0。

在这个例子中,我们通过 梯度下降法 来优化 logits_per_image 中的正样本分数(即对角线上的值,例如 2.0)。以下是详细的步骤,包括梯度计算、权重更新以及下一轮的损失变化。


假设前提

  • 初始相似度矩阵(logits_per_image):
    logits_per_image = [ 2.0 0.5 − 1.0 0.3 1.8 0.2 − 0.5 0.4 1.5 ] \text{logits\_per\_image} = \begin{bmatrix} 2.0 & 0.5 & -1.0 \\ 0.3 & 1.8 & 0.2 \\ -0.5 & 0.4 & 1.5 \end{bmatrix} logits_per_image= 2.00.30.50.51.80.41.00.21.5
  • 学习率:( η = 0.1 \eta = 0.1 η=0.1 )
  • 批量大小:( batch_size = 3 \text{batch\_size} = 3 batch_size=3 )
  • 目标:通过一轮梯度下降,提升正样本的相似度,同时降低负样本的相似度。

Step 1: 计算梯度

Softmax 概率计算

以第一行 ( logits [ 0 , : ] \text{logits}[0, :] logits[0,:] ) 为例:
logits [ 0 , : ] = [ 2.0 , 0.5 , − 1.0 ] \text{logits}[0, :] = [2.0, 0.5, -1.0] logits[0,:]=[2.0,0.5,1.0]
对应的 softmax 概率为:
P [ i , j ] = exp ⁡ ( logits [ i , j ] ) ∑ k exp ⁡ ( logits [ i , k ] ) P[i, j] = \frac{\exp(\text{logits}[i, j])}{\sum_{k} \exp(\text{logits}[i, k])} P[i,j]=kexp(logits[i,k])exp(logits[i,j])
计算分母:
denominator = exp ⁡ ( 2.0 ) + exp ⁡ ( 0.5 ) + exp ⁡ ( − 1.0 ) ≈ 7.389 + 1.649 + 0.368 = 9.406 \text{denominator} = \exp(2.0) + \exp(0.5) + \exp(-1.0) \approx 7.389 + 1.649 + 0.368 = 9.406 denominator=exp(2.0)+exp(0.5)+exp(1.0)7.389+1.649+0.368=9.406
正样本 ( P ( positive ) P(\text{positive}) P(positive) ):
P ( positive ) = exp ⁡ ( 2.0 ) denominator = 7.389 9.406 ≈ 0.785 P(\text{positive}) = \frac{\exp(2.0)}{\text{denominator}} = \frac{7.389}{9.406} \approx 0.785 P(positive)=denominatorexp(2.0)=9.4067.3890.785
负样本 ( P ( negative , j = 2 ) P(\text{negative}, j=2) P(negative,j=2) ):
P ( negative , j = 2 ) = exp ⁡ ( 0.5 ) denominator = 1.649 9.406 ≈ 0.175 P(\text{negative}, j=2) = \frac{\exp(0.5)}{\text{denominator}} = \frac{1.649}{9.406} \approx 0.175 P(negative,j=2)=denominatorexp(0.5)=9.4061.6490.175
负样本 ( P ( negative , j = 3 ) P(\text{negative}, j=3) P(negative,j=3) ):
P ( negative , j = 3 ) = exp ⁡ ( − 1.0 ) denominator = 0.368 9.406 ≈ 0.039 P(\text{negative}, j=3) = \frac{\exp(-1.0)}{\text{denominator}} = \frac{0.368}{9.406} \approx 0.039 P(negative,j=3)=denominatorexp(1.0)=9.4060.3680.039

交叉熵损失对 logits 的梯度

交叉熵损失公式:
Loss i = − log ⁡ ( P ( positive ) ) \text{Loss}_i = -\log(P(\text{positive})) Lossi=log(P(positive))
对 ( logits [ 0 , : ] \text{logits}[0, :] logits[0,:] ) 的梯度计算:
∂ Loss i ∂ logits [ 0 , j ] = P [ i , j ] − δ i , j \frac{\partial \text{Loss}_i}{\partial \text{logits}[0, j]} = P[i, j] - \delta_{i, j} logits[0,j]Lossi=P[i,j]δi,j
其中 ( δ i , j \delta_{i, j} δi,j ) 是 Kronecker delta,表示只有正样本(即对角线)位置是 1,其余为 0。

对于第一行:

  • 正样本 ( j = 0 j=0 j=0 ):
    ∂ Loss 0 ∂ logits [ 0 , 0 ] = P [ 0 , 0 ] − 1 = 0.785 − 1 = − 0.215 \frac{\partial \text{Loss}_0}{\partial \text{logits}[0, 0]} = P[0, 0] - 1 = 0.785 - 1 = -0.215 logits[0,0]Loss0=P[0,0]1=0.7851=0.215
  • 负样本 ( j = 1 j=1 j=1 ):
    ∂ Loss 0 ∂ logits [ 0 , 1 ] = P [ 0 , 1 ] = 0.175 \frac{\partial \text{Loss}_0}{\partial \text{logits}[0, 1]} = P[0, 1] = 0.175 logits[0,1]Loss0=P[0,1]=0.175
  • 负样本 ( j = 2 j=2 j=2 ):
    ∂ Loss 0 ∂ logits [ 0 , 2 ] = P [ 0 , 2 ] = 0.039 \frac{\partial \text{Loss}_0}{\partial \text{logits}[0, 2]} = P[0, 2] = 0.039 logits[0,2]Loss0=P[0,2]=0.039

Step 2: 更新 logits

使用梯度下降法更新:
logits [ i , j ] = logits [ i , j ] − η ⋅ ∂ Loss i ∂ logits [ i , j ] \text{logits}[i, j] = \text{logits}[i, j] - \eta \cdot \frac{\partial \text{Loss}_i}{\partial \text{logits}[i, j]} logits[i,j]=logits[i,j]ηlogits[i,j]Lossi

对于第一行:

  • 正样本 ( j = 0 j=0 j=0 ):
    logits [ 0 , 0 ] = 2.0 − 0.1 ⋅ ( − 0.215 ) = 2.0 + 0.0215 = 2.0215 \text{logits}[0, 0] = 2.0 - 0.1 \cdot (-0.215) = 2.0 + 0.0215 = 2.0215 logits[0,0]=2.00.1(0.215)=2.0+0.0215=2.0215
  • 负样本 ( j = 1 j=1 j=1 ):
    logits [ 0 , 1 ] = 0.5 − 0.1 ⋅ 0.175 = 0.5 − 0.0175 = 0.4825 \text{logits}[0, 1] = 0.5 - 0.1 \cdot 0.175 = 0.5 - 0.0175 = 0.4825 logits[0,1]=0.50.10.175=0.50.0175=0.4825
  • 负样本 ( j = 2 j=2 j=2 ):
    logits [ 0 , 2 ] = − 1.0 − 0.1 ⋅ 0.039 = − 1.0 − 0.0039 = − 1.0039 \text{logits}[0, 2] = -1.0 - 0.1 \cdot 0.039 = -1.0 - 0.0039 = -1.0039 logits[0,2]=1.00.10.039=1.00.0039=1.0039

更新后的第一行 logits:
logits [ 0 , : ] = [ 2.0215 , 0.4825 , − 1.0039 ] \text{logits}[0, :] = [2.0215, 0.4825, -1.0039] logits[0,:]=[2.0215,0.4825,1.0039]


Step 3: 下一轮的损失计算

使用更新后的 logits 重新计算 softmax 概率和损失:
logits [ 0 , : ] = [ 2.0215 , 0.4825 , − 1.0039 ] \text{logits}[0, :] = [2.0215, 0.4825, -1.0039] logits[0,:]=[2.0215,0.4825,1.0039]

分母:
denominator = exp ⁡ ( 2.0215 ) + exp ⁡ ( 0.4825 ) + exp ⁡ ( − 1.0039 ) ≈ 7.562 + 1.620 + 0.367 = 9.549 \text{denominator} = \exp(2.0215) + \exp(0.4825) + \exp(-1.0039) \approx 7.562 + 1.620 + 0.367 = 9.549 denominator=exp(2.0215)+exp(0.4825)+exp(1.0039)7.562+1.620+0.367=9.549
正样本概率:
P ( positive ) = exp ⁡ ( 2.0215 ) denominator = 7.562 9.549 ≈ 0.792 P(\text{positive}) = \frac{\exp(2.0215)}{\text{denominator}} = \frac{7.562}{9.549} \approx 0.792 P(positive)=denominatorexp(2.0215)=9.5497.5620.792
损失:
Loss = − log ⁡ ( P ( positive ) ) = − log ⁡ ( 0.792 ) ≈ 0.233 \text{Loss} = -\log(P(\text{positive})) = -\log(0.792) \approx 0.233 Loss=log(P(positive))=log(0.792)0.233

对比:

  • 更新前损失:( 0.34 )
  • 更新后损失:( 0.233 )

总结

通过一轮梯度下降:

  1. 正样本相似度提升(从 2.0 增加到 2.0215)。
  2. 负样本相似度下降(例如从 0.5 减少到 0.4825)。
  3. 损失函数减小,从 ( 0.34 ) 减少到 ( 0.233 ),说明模型朝着更优的方向优化。

这种优化方式有效地让正样本对更加接近,而负样本对更加远离,从而提升模型在对比学习任务中的表现。

后记

2024年12月13日21点56分于上海,在GPT4o大模型辅助下完成。


http://www.ppmy.cn/server/150071.html

相关文章

如何解决 java.lang.IndexOutOfBoundsException 异常问题?亲测有效的解决方法!

IndexOutOfBoundsException 是 Java 中常见的运行时异常,表示访问了无效的索引(数组、集合、字符串等)。本文将从原因分析到解决方法,并提供真实案例和代码示例,帮你彻底解决这个问题。 1. 问题分析 抛出 IndexOutOfB…

重生之我在异世界学编程之C语言:深入函数递归篇

大家好,这里是小编的博客频道 小编的博客:就爱学编程 很高兴在CSDN这个大家庭与大家相识,希望能在这里与大家共同进步,共同收获更好的自己!!! 函数递归与迭代 引言正文一、递归的基本概念二、递…

vue3+setup使用rtsp视频流实现实时监控,全屏,拍摄,自动拍摄等功能(纯前端)

vue3setup使用rtsp视频流实现实时监控,全屏,拍摄,自动拍摄等功能(纯前端) 概要 本文介绍了如何在Vue应用中通过WebRTC技术获取摄像头的rtsp视频流,同时展示了实时监控,全屏,拍摄,自动拍摄等功…

利用GeoWave导入矢量数据到HBase/Accumulo数据库

前言 最近在做有关地理时空大数据的实验,本文将介绍如何利用geowave框架,将矢量数据导入到HBase或Accumulo等NoSQL数据库中。 软件版本: Hadoop: 2.10.2 Zookeeper: 3.6.4 geowave: 1.2.0 Accumulo:1.9.3 HBase: 1.4.0 Ja…

MongoDB-单键索引与复合索引

在 MongoDB 中,索引是提高查询性能的一个重要手段。通过为集合中的字段创建索引,可以显著加快对数据的检索速度。MongoDB 支持多种类型的索引,其中 单键索引 和 复合索引 是最常用的两种类型。了解这两种索引的工作原理、使用场景以及区别&am…

本地体验新版springcloud-搭建工程学习笔记

为了快速体验下新版本springcloud.对照b站图灵视频简单记录下。起码入门不要钱,值得推荐。 基础知识: 会用springboot写demo。 会用mybatis操作MYSQL。 会用git拉取代码。 这都是基本操作。 环境准备: jdk17 demo是21.我实际测试17也可…

Nginx之配置防盗链(Configuring Anti-hotlinking in Nginx)

运维小白入门——Nginx配置防盗 什么是防盗链: 防盗链技术主要用于防止未经授权的第三方或域名访问网站的静态资源。例如,一个网站可能拥有独特的图片素材,为了防止其他网站通过直接链接图片URL的方式访问这些图片,网站管理员会采…

MedLSAM: 用于3D CT图像的局部化和分割模型|文献速递-生成式模型与transformer在医学影像中的应用

Title 题目 MedLSAM: Localize and segment anything model for 3D CT images MedLSAM: 用于3D CT图像的局部化和分割模型 01 文献速递介绍 最近,计算机视觉领域对开发大规模的基础模型的兴趣不断增加,这些模型能够同时处理多个视觉任务&#xff0c…