【提示学习】PromptSync论文问题汇总

news/2024/10/22 15:29:35/

文章目录

  • PromptSync: Bridging Domain Gaps in Vision-Language Models through Class-Aware Prototype Alignment and Discrimination(2024CVPR)
  • 1 Introduction
  • 2 Related Work
    • 2.1 CLIP
    • 2.2 TPT
  • 3 Methodology
    • 3.1 提出方法PromptSync
    • 3.2 类感知原型生成(视觉原型?语言原型?)
        • Q:为什么不包括class token?
    • 3.3 原型判别损失
      • 3.3.1 正对样本损失
      • 3.3.2 负对样本损失
        • Q:为什么不用计算𝑐𝑘的增强视图和其他所有类别增强视图的相似度?
      • 3.3.3 最终的优化目标
    • 3.4 原型对齐损失
      • 3.4.1 振幅对齐损失
      • 3.4.2 角度对齐损失
      • 3.4.3 合并
    • 3.5 算法的细节
      • 3.5.1 计算原型判别损失
      • 3.5.2 测试时间适应过程
      • 3.5.3 多次迭代更新
  • 4 实验
    • 4.1 baseline对比
    • 4.2 实施细节
    • 4.3 领域泛化
    • 4.4 Base to Novel
    • 4.5 跨数据集转移性能
  • 5 消融实验
  • 6 性能和延迟
  • 7 敏感性比较
  • 8 LAION400M代理数据集分析
  • 9 Conclusion
  • TPT(2022 NeurIPS)
  • PromptAlign(2023 NeurIPS)

PromptSync: Bridging Domain Gaps in Vision-Language Models through Class-Aware Prototype Alignment and Discrimination(2024CVPR)

  • 提出类别级的原型对齐方法,将每个测试样本与源分布对齐,减轻类间分布迁移的影响
  • 我们在文本和视觉分支上都进行了提示调整
  • 将测试样本的原型与预先计算的类原型对齐
  • 按照从增强视图中获得的每个类的平均概率加权来调整可学习的提示令牌

1 Introduction

  • 提出了一种面向类别的原型对齐技术,用于对齐每个测试样本的上下文与类别源分布基础上,从而减轻类别之间的分布偏移效应。
  • 提出了面向类别的原型判别,以发现有效对齐的类别分布。此外,我们还提出了从代理源数据集进行类别原型的离线计算,用于基础V-L模型。
  • 提出了针对文本和视觉分支的多模态测试时提示调整。基于从基础到新颖的泛化、领域泛化以及跨数据集转移的实证评估显示了我们方法的效率高于现有方法。
    在这里插入图片描述

2 Related Work

clip里面取max,PromptSync变成了取平均?

2.1 CLIP

在这里插入图片描述

  • Clip测试阶段,图像特征与文本特征做余弦相似度计算,相似度最大的即为对应的类别。

2.2 TPT

在这里插入图片描述
在这里插入图片描述

  • 在过滤后的增强视图上,模型产生的向量类概率的平均值,即为平均类概率,平均类概率作为权重,对齐类原型与过滤增强试图。

3 Methodology

3.1 提出方法PromptSync

3.2 类感知原型生成(视觉原型?语言原型?)

代理数据集:用于训练模型的数据集,在本文中指定了代理数据集
原型:对于每个类别原型,定义为该类别所有样本特征向量的平均值
生成类感知原型:

在这里插入图片描述

  • h x t h_x^t hxt:样本x在文本t上的原型向量
  • h x v h_x^v hxv:样本x在视觉v上的原型向量
  • h C L S , x v h_{CLS,x}^v hCLS,xv:样本x在视觉v上 [CLS] token的原型向量
  • ET (x, ei):样本x的第i个token在文本编码器T的输出
  • EV (x, ei):样本x的第i个token在图像编码器V的输出
  • P=所有tokens的数量(包括可学习、不可学习、文本、图像)(不包括SOS、EOS、CLS)
  • token:文本数据中的基本单元,通常是一个词或一个字符,每个token都会被映射成一个对应的向量表示,向量表示了token的语义信息。

在这里插入图片描述

Q:为什么不包括class token?

在文本原型计算时,每个类别计算都去掉了SOS、EOS、CLS,用的是(t1、t2、…、tL),那计算出来的文本原型,都是一样的?

3.3 原型判别损失

训练可学习提示,使用对比学习的方法,拉近同一类别样本在嵌入空间中的距离,将不同类别的样本推开,实现更好的样本分类和原型分布

3.3.1 正对样本损失

在这里插入图片描述

L p o s ( c k ) \mathcal{L}_{pos}(c_k) Lpos(ck) :正对样本positive的损失,拉近同类别原型和增强视图
计算了每个增强视图𝑎𝑢𝑔与类别 𝑐𝑘的原型向量 ℎ𝑐𝑘𝑚之间的相似度,将相似度值取指数,进行加权平均

3.3.2 负对样本损失

在这里插入图片描述

L n e g ( c k ) \mathcal{L}_{neg}(c_k) Lneg(ck) :负对样本negative的损失,推开不同类别原型和增强视图
分成三部分

  • 𝑐𝑘原型向量和其他所有类别原型向量hcm的相似度
  • 𝑐𝑘的增强视图和其他所有类别的原型向量hcm的相似度
  • 𝑐𝑘原型向量和其他所有类别增强视图的相似度
Q:为什么不用计算𝑐𝑘的增强视图和其他所有类别增强视图的相似度?

3.3.3 最终的优化目标

在这里插入图片描述
在这里插入图片描述

L D \mathcal{L}_{D} LD :正对样本损失和负对样本损失的比率的负对数,即最终的优化目标

  • 最小化ld,即为最大化求和的部分
  • 最大化lpos(拉近本身与增强图像的相似度)
  • 最小化lneg(减小本身与其他类别的相似性)

3.4 原型对齐损失

  • Ld能够有效区分不同的类别,但无法调整测试样本的提示
  • 提出测试样本及其增强视图,与源分布中类原型的对齐
  • 对于每个测试样本𝑥𝑖,以及每个类别𝑐,计算测试样本𝑥𝑖的原型
  • p x i m p_{xi}^m pxim与类别𝑐的类原型 𝑝𝑐𝑚之间的振幅对齐损失和角度对齐损失
  • pˆp[c] :测试样本最可能的类别,均值概率,作为LA的权重,作者后面会讲到

3.4.1 振幅对齐损失

测试样本的原型与类原型之间的距离
在这里插入图片描述

3.4.2 角度对齐损失

测试样本的原型与类原型之间的角度相似度
在这里插入图片描述
我们要最大化他们的角度相似度,因此最大化L’ang

3.4.3 合并

在计算损失时,均方误差损失对于一定范围内的误差增加会给予相等的惩罚,而我们希望在小范围内的误差增加时给予更大的惩罚,因此作者将损失取对数。
在这里插入图片描述

在这里插入图片描述
其中,最大化角度相似度,因此最大化L’ang,最小化Lang
在这里插入图片描述

3.5 算法的细节

3.5.1 计算原型判别损失

在源数据集上计算原型判别损失需要使用 CLIP 模型的预训练数据集,CLIP 模型是在超过 4 亿个图像文本对上进行训练的,数据不公开可用。因此,为了近似源数据集,作者选择使用了 ImageNet 数据集。在 ImageNet 上计算出每个类别的原型,这些原型是离线计算的,包括了样本和其增强视图。

3.5.2 测试时间适应过程

在每次迭代的测试中

  • 元训练阶段:使用原判别目标函数LD进行训练,计算梯度,得到更新后的提示
  • 元测试阶段:使用更新后的提示,设置置信度阈值,过滤增强视图的预测概率,计算在F上的均值概率p,并作为LA中的权重。计算梯度。
  • 计算梯度平均值,使用组合目标更新提示
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

3.5.3 多次迭代更新

n>1时,会累计平均梯度,然后进行最终的提示更新

4 实验

数据集:

  • 作者在ImageNetV2、ImageNet-Sketch、ImageNet-A 和 ImageNet-R进行评估
  • 还考虑了Photorealistic Unreal Graphics (PUG) 数据集(包括不同的纹理、大小、方向和背景)
  • 对于跨数据集转移设置,作者考虑了10个不同的图像分类数据集,包括 Caltech 101、StanfordCars、Food101、Flowers102、FGVC-Aircraft、OxfordPets、SUN397、DTD、UCF101 和 EUROSAT

4.1 baseline对比

包括 CoOp、CoCoOp、TPT 、 PromptAlign、MaPLe

4.2 实施细节

  • 在单个 NVIDIA A100 40GB GPU 上运行了所有实验
  • 在 ImageNet 上进行了训练,使用随机选择的 16 张图像作为每个类别的训练数据
  • 使用 2 个提示标记进行 3 层深度的训练
  • 图像增强:使用随机裁剪、背景替换、水平翻转增强和视觉损坏,对每个测试图像进行了 127 个不同视图的增强
  • 文本增强:作者使用了 WordNet 中的同义词、反义词和部分词

4.3 领域泛化

在这里插入图片描述

表1,对比了各种方法在不同数据集上的性能,平均值表示了对所有领域的平均性能。

表2中,着重比较了在领域泛化设置下针对分布对齐的性能,具体指标包括相机姿态、姿势、尺度、纹理、光照和世界。

4.4 Base to Novel

在这里插入图片描述

MaPLE+TPT后部分会下降

4.5 跨数据集转移性能

在这里插入图片描述

5 消融实验

在这里插入图片描述

表5,熵损失、对齐损失、判别损失的消融实验
在这里插入图片描述
表6,对齐损失的消融实验

6 性能和延迟

在这里插入图片描述
延迟:单个提示更新的时间(小时)
PromptSync*变体展示了更快的处理时间,而性能仅略有下降。这个结果强调了通过原型对齐实现的泛化。

7 敏感性比较

在这里插入图片描述
图2a,随着增强视图数量的增加,准确率上升
图2b,准确率随着提示更新步次数的增加而提高

8 LAION400M代理数据集分析

我们选择ImageNet作为可行的代理源数据集,使用LAION400M的子集

9 Conclusion

总之,PromptSync显著改善了视觉语言模型中的zero-shot泛化。我们的方法解决了类优势和方差问题,总体上比现有方法高出2.33%,在领域泛化基准上,从基础到新的泛化提高了1%,跨数据集传输提高了2.84%。这强调了PromptSync在增强视觉语言模型稳健性方面的有效性。

TPT(2022 NeurIPS)

imagenet里面没有的类别,怎么对齐?

PromptAlign(2023 NeurIPS)

多模态测试时间提示调优方法
将视觉分支中测试样本的令牌分布与完整代理源数据集的预计算统计数据对齐,而不考虑一个类分布可能具有与其他类不同的均值和方差。


http://www.ppmy.cn/news/1439680.html

相关文章

AMBA-CHI协议详解(二)

《AMBA 5 CHI Architecture Specification》 文章目录 2.1 Channels综述2.2 Channel域段2.2.1 request fields2.2.2 Response fields2.2.3 Snoop request fields2.2.4 Data fields 2.3 事务结构2.3.1 Read transactions2.3.1.1 Allocating Read2.3.1.2 Non-allocating Read 2.…

LeetCode 0039.组合总和:回溯 + 剪枝

【LetMeFly】39.组合总和:回溯 剪枝 力扣题目链接:https://leetcode.cn/problems/combination-sum/ 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target ,找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合…

【后端学习笔记·Golang】手机短信验证

文章目录 手机号码验证前置准备开通阿里云sms服务获取AccessKey并下载sdk 生成随机验证码将验证码发送到用户手机接口发送验证码校验验证码 手机号码验证 流程: 接收用户请求后生成随机验证码,并将验证码存入Redis中,并设置TTL通过阿里云sd…

Python问题(报错)大全 - 不定时更新

问题汇总 1.打包问题1.1打包后报错1.1.1 Pgzero1.1.1.1KeyError: "No images directory found to load image img.png."[36616] Failed to execute script file due to unhandled exception!1.1.1.1.1 问题描述1.1.1.1.2 问题解释1.1.1.1.3 问题解决方案 2.常规运行问…

为什么 Facebook 不使用 Git?

在编程的世界里,Git 就像水一样常见,以至于我们认为它是创建和管理代码更改的唯一可行的工具。 前 Facebook 员工,2024 年 首先,我为什么关心? 我致力于构建 Graphite,它从根本上受到 Facebook 内部工具的…

scss基础和css扩展

变量 定义变量 //app.scss $allpadding:20px; //声明颜色变量 $color//使用 import /assets/app.scss;.container{width: 100%;padding:$allpadding;} ⚠️scss中,中下划线和下划线是同一个东西 $link-color: blue; a {color: $link_color; }//编译后a {color: …

Go语言中,两个比较流行的缓存库

在 Go 中实现带有过期时间的缓存通常需要一个可以自动处理键值过期的缓存系统。虽然标准库中没有直接提供这种功能,但有几个流行的第三方库可以很好地满足这一需求。下面我会介绍两个比较流行的 Go 缓存库:go-cache 和 bigcache。 1. go-cache go-cache…

DaPy:实现数据分析与处理

DaPy:实现数据分析与处理 DaPy是一个用于数据分析和处理的Python库,它提供了一系列强大的工具和功能,使开发者能够高效地进行数据清洗、转换和分析。本文将深入解析DaPy库的特点、功能以及使用示例,帮助读者了解如何利用DaPy库处理…