VLM--CLIP作分类任务的损失函数

devtools/2024/12/27 1:42:33/

info_nce_loss

这个是clip作对比学习的损失函数
各个博客上都有详细介绍了,我这里就不赘述

def info_nce_loss(image_features, text_features,logit_scale,labels, temperature=0.07):batch_size = image_features.shape[0]image_features = image_features / image_features.norm(dim=-1, keepdim=True)text_features = text_features / text_features.norm(dim=-1, keepdim=True)similarity_matrix = torch.matmul(image_features, text_features.T) / temperaturelogits_per_image = similarity_matrixlogits_per_text = similarity_matrix.T# 构造标签,正样本对应的位置为1,其余为0,这里假设批次内第一个文本特征是对应图像的正样本文本特征gen_labels = torch.arange(batch_size).long().to(image_features.device)total_loss = (F.cross_entropy(logits_per_image, gen_labels)+F.cross_entropy(logits_per_text, gen_labels))/2return total_loss, logits_per_image, logits_per_text

我踩的坑

微调 c l i p clip clip分类任务类别数为3

  1. 数据集为图像-文本对数据集:即一个数据样本为一个图像和对应的文本在json文件里。这里每个类别的图像的文本都是一样的,也就是a类别下图像可能会有细微不同,但是文本都是一样的
  2. 微调 c l i p clip clip 的结构同原始 c l i p clip clip 一致,输出的图像特征维度为 [ 输入图像数量 , 512 ] [输入图像数量,512] [输入图像数量,512],文本特征维度为 [ 输入的文本数量 , 512 ] [输入的文本数量,512] [输入的文本数量,512]这里选用不同的clip结构,输出维度可能有所不同
  3. 我微调过程输入 c l i p clip clip 的数据为 b a t c h _ s i z e batch\_size batch_size个图像、文本。输出的logit维度为 [ b a t c h _ s i z e , b a t c h _ s i z e ] [batch\_size,batch\_size] [batch_size,batch_size]

当使用 c l i p clip clip 去做分类任务假设类别为3时,直接使用上面的损失函数并不合适
因为:
g e n _ l a b e l s gen\_labels gen_labels会产生一个 [ 0 … … b a t c h _ s i z e − 1 ] [0……batch\_size-1] [0……batch_size1]的序列,接着和 l o g i t logit logit 做交叉熵。这里的 l o g i t logit logit 维度为 [ b a t c h _ s i z e , b a t c h _ s i z e ] [batch\_size,batch\_size] [batch_size,batch_size]

这意味着: l o g i t logit logit的对角线处的数据才会被 l o s s loss loss记录即第 i i i 个图像和第 i i i 个文本才是匹配的正样本,其余的为负样本。

这跟我实验设置下的分类任务有所冲突:因为我只有3个类别,而对于 l o g i t logit logit的第 i i i 行(即第 i i i 图像),只会跟第 i i i 列(即第 i i i 个文本)是正样本,而第 i i i 个图像应该和不止一个文本是正样本。例如:第0行图像和第0列的文本是正样本,还会和第 0 + 3 i , i = 0 , 1 , 2 … … 0+3i,i=0,1,2…… 0+3ii=0,1,2……列的文本是正样本,而 i n f o _ n c e _ l o s s info\_nce\_loss info_nce_loss会忽略掉后面的正样本

导致微调出来的 A C C ACC ACC F 1 F1 F1 都比较低

clip选用这样的损失函数,是因为其并不是做分类任务,而是直接用海量的互联网数据去预训练(a类别下图像可能会有细微不同,但是文本都是一样的这个情况存在的可能性小)

在这里插入图片描述

clip分类任务损失函数

def info_nce_loss(image_features, text_features,logit_scale,labels, temperature=0.07):"""计算InfoNCE损失函数,模拟CLIP中的对比学习损失计算参数:image_features (torch.Tensor): 图像特征表示,形状为 [batch_size, feature_dim]text_features (torch.Tensor): 文本特征表示,形状为 [batch_size, feature_dim]temperature (float): 用于缩放相似度得分的温度参数,控制分布的平滑程度返回:loss (torch.Tensor): InfoNCE损失值"""batch_size = image_features.shape[0]image_features = image_features / image_features.norm(dim=-1, keepdim=True)text_features = text_features / text_features.norm(dim=-1, keepdim=True)similarity_matrix = torch.matmul(image_features, text_features.T) / temperaturelogits_per_image = similarity_matrixlogits_per_text = similarity_matrix.Tgen_labels = labelstotal_loss = F.cross_entropy(logits_per_image, gen_labels)return total_loss, logits_per_image, logits_per_text
  1. 给每个图像-文本对记录类别 l a b e l label label
  2. 改变文本输入,每个 b a t c h _ s i z e batch\_size batch_size下输入的文本维度为 [ n _ c l a s s , ] [n\_class,] [n_class,],经过 c l i p _ e n c o d e r clip\_encoder clip_encoder 后维度为 [ n _ c l a s s , 512 ] [n\_class,512] [n_class,512]
  3. 接着做交叉熵计算

http://www.ppmy.cn/devtools/145677.html

相关文章

【大选】2024年美国总统选举数据分析可视化

前言 • 👓 可视化主要使用 Plotly • 🔎 数据处理主要使用 pandas • 👉 本文是我自己在和鲸社区的原创 1.项目背景描述 2024年美国大选是该国政治生活中的重要事件,吸引了全球的关注。本报告通过对选举数据的分析&#xff0c…

Java 网络编程 ②-TCP Socket

这里是Themberfue 在上一节中,我们简单认识了 TCP协议 和 UDP协议 以及 基于UDP Socket 编写了简单的网络通信代码 本节我们将基于 TCP Socket 编写简单的网络通信代码 TCP Socket 类似 UDP Socket,Java也基于 TCP协议 进行了一些接口的封装 基于 TCP 封…

[计算机图形学] 【Unity Shader】【图形渲染】Shader数学基础6-逆矩阵与正交矩阵

在计算机图形学与Shader编程中,矩阵广泛应用于各种变换操作,如旋转、缩放、平移等。理解矩阵的基本性质,尤其是逆矩阵和正交矩阵,对于有效地实现图形变换至关重要。本文将介绍逆矩阵和正交矩阵的数学基础,帮助你更好地理解这些概念及其在图形学中的应用。 逆矩阵的基本概…

玄机-第一章 应急响应-webshell查杀

靶机账号密码 root xjwebshell 1.黑客webshell里面的flag flag{xxxxx-xxxx-xxxx-xxxx-xxxx} 2.黑客使用的什么工具的shell github地址的md5 flag{md5} 3.黑客隐藏shell的完整路径的md5 flag{md5} 注 : /xxx/xxx/xxx/xxx/xxx.xxx 4.黑客免杀马完整路径 md5 flag{md5} 1.黑客web…

React第十八节 useEffect 用法使用技巧注意事项详解

1、概述 useEffect 是React中一个用于 将组件与外部系统同步的 Hook;在函数式组件中处理副作用函数的 Hook,用于替代类式组件中的生命周期函数; 可以在副作用函数中 实现以下操作: a、请求接口,获取后台提供数据 b、操…

【漏洞复现】F5 BIG-IP Next Central Manager SQL注入漏洞(CVE-2024-26026)

🏘️个人主页: 点燃银河尽头的篝火(●’◡’●) 如果文章有帮到你的话记得点赞👍+收藏💗支持一下哦 一、漏洞概述 1.1漏洞简介 漏洞名称:F5 BIG-IP Next Central Manager SQL注入漏洞漏洞编号:CVE-2024-26026漏洞威胁等级:超危影响范围:BIG-IP Next Central Manage…

基于CNN-BiLSTM-selfAttention混合神经网络的多分类预测【MATLAB】

在深度学习中,不同神经网络架构的组合往往可以实现更强大的表现。将卷积神经网络(CNN)、双向长短期记忆网络(BiLSTM)和自注意力机制(Self-Attention)结合在一起,可以充分发挥三者的优…

springBoot发布https服务及调用

一、服务端发布https服务 1、准备SSL证书 (1)自签名证书:如果你只是用于开发或测试环境,可以生成一个自签名证书。 (2)CA 签名证书:对于生产环境,应该使用由受信任的证书颁发机构 …