多标签分类SOTA | ADDS论文解读

news/2024/12/12 16:50:45/
  • 论文标题:Open Vocabulary Multi-Label Classification with Dual-Modal Decoder on Aligned Visual-Textual Features
  • 论文传送门:https://arxiv.org/pdf/2208.09562

  • paperwithcode多标签分类排名:

这篇文章主要用来解决多标签分类问题,具体针对未见过的标签进行分类。这篇文章提出一种用于开放词汇多标签分类任务的新算法——对齐双模态分类器(Aligned Dual moDal ClaSsifier, ADDS)。设计了一种简单而有效的金字塔转发方法来提高高分辨率输入的性能,并采用了选择性语言监督,进一步提高了模型的性能。

在多个标准基准(NUS-WIDE、ImageNet-1k、ImageNet-21k和MS-COCO)上进行的大量实验表明,我们的方法显著优于以前的方法,并为开放词汇表多标签分类、传统多标签分类和极端情况下的单标签到多标签分类提供了最先进的性能,其中模型在单标签数据集(ImageNet-1k、ImageNet-21k)在多标签(MS-COCO和NUS-WIDE)上进行了测试。

一、引言

传统分标签分类会遇到的问题:不能满足一些实际应用的需要,因为在测试过程中可能会出现看不见的标签。以前,带有不可见标签的任务通常被称为多标签零样本学习,相关方法通常是基于标签相关性创建的,这些方法试图识别图像内标签之间的潜在关系,以方便分类,通常是通过构建标签图。

对于最先进的(SOTA)方法ML-Decoder[41],仅从单词中创建标签嵌入,或为每个标签分配可学习的嵌入,而不是间接地从图关系中创建和学习,从而允许更复杂的结果。最常用的标签嵌入之一是Word2Vec[34],它是通过使用外部文本数据集的预训练任务创建的(在预训练期间可以看到目标标签类),然后根据不同的标签嵌入提取局部判别特征,输出每类概率。然而,缺点:

  1. 这些模型只关注单词之间的联系,并且将学习到的映射(从图像到可见标签)推广到目标映射(从图像到未见标签)仍然具有挑战性

  2. 单词嵌入方法会阻碍了模型处理词组或者句子标签,这在实践中也存在,并且非常具有挑战性。

当前已经有论文引入Open-Vocabulary设置,它是零样本和弱监督设置的推广,更适合于处理看不见的类。虽然目标类在训练过程中是未知的,但它可以是预训练任务中整个语言词汇的任何子集(例如,图像标题数据集的对比学习)。利用在图像标题数据集上训练的视觉语言预训练(VLP)模型来帮助建立视觉和文本嵌入之间的联系,从而在算法设计上提供更大的灵活性,而不是在分类数据集上使用昂贵的注释。

该论文提出了一种基于对齐视觉和文本嵌入的开放词汇多标签分类框架(对齐双模态分类器)。该框架包括一种新颖的DMdecoder (dual - modal decoder)设计,它利用双模态通过逐步融合视觉嵌入与文本信息和开发更丰富的语义理解来增强变压器解码器。它还包括一个金字塔转发方法,使模型在低分辨率图像上预训练到高分辨率图像,而无需重新训练。总的来说,我们的工作做出了以下技术贡献:

  • 我们开发了一个开放词汇表的多标签分类框架,它建立在对齐的视觉和文本特征上。该框架包括DM-Decoder和Pyramid-Forwarding。DM-Decoder是一种促进双模信息源语义融合的新型变压器解码器,Pyramid-Forwarding是一种新的自适应方法,可以处理比训练图像更高分辨率的图像,并且可以大大降低视觉转换器的计算成本。

  • 我们在各种多标签分类任务中进行了广泛的实验。我们的add框架在所有场景下都明显优于以前的SOTA方法,包括开放词汇多标签分类(在NUS-WIDE上提高11.57分),单到多标签分类(从ImageNet-1k到MS-COCO和NUS-WIDE分别提高24.71分和16.49分),以及传统的多标签分类(例如,在MS-COCO上提高2.14分)。

二、相关工作

2.1 常规多标签分类

传统的方法分为两组:

  1. 基于感兴趣的区域:通过定位图像中的每个对象或捕获注意图,然后对其进行单标签分类来解决的。然而,这些方法往往存在发现区域粗糙、计算成本高、某些概念或场景难以定位、某些区域包含重复概念等问题。

  2. 基于标签相关性:探寻训练图像中标签之间的潜在关系,将特征表示分解为特定于类别语义的表示,并应用图神经网络来探索它们之间的相互作用。

2.2 多标签零样本分类

对于零样本分类也就是测试或者评估数据的类别没有出现在训练数据中,从3个方法角度来解释这个问题:

  1. 探索标签之间的联系来捕捉看不见的标签:认为每个类都是嵌入的属性向量的空间,然后引入一个度量图像与标签嵌入之间兼容性的函数来确定正确的类

  2. 估计图像的主方向来研究图像-词相关性:基于给定图像的相关标签的词向量在词向量空间中沿主方向可以排在不相关标签之前的假设;

  3. 依靠文本特征之间的关系进行学习:典型代表ml-decoder,但由于使用Word2Vec生成标签文本嵌入,文本嵌入学习过程中缺乏对视觉信息的监督,学习到的图像与文本之间的映射很难推广到不可见的数据空间。

2.3 视觉语言预训练(VLP)

视觉语言预训练通过对不同任务的大规模数据进行预训练来学习图像和语言之间的语义对应关系。在本文中,我们主要通过CLIP[38]来保持视觉和文本嵌入的对齐,CLIP是建立在图像和文本嵌入对之间的余弦相似度的基础上,并在一个大而有噪声的数据集上进行训练。

2.4 开放性词汇学习

VLP模型通过学习大规模的训练语料库,实现了图像与相应文本信息之间的强连接。所谓“开放性”就是训练后的模型对图片进行分类时,可选的分类候选文本可以是任意文本集合,预测的文本和训练的文本不管是内容或者顺序都无需一模一样。

三、方法论

3.1 综述

详细的结构如下:

受VLP的启发,我们在CLIP预训练模型的帮助下构建了视觉语义对齐。具体来说,我们采用视觉变压器(vision transformer, ViT)[9]网络架构作为图像编码器模型,采用多层变压器作为文本编码器模型,两个编码器的参数均来自CLIP。在训练期间,他们都被冻结以保持对齐(如果解冻可能会导致更糟糕的结果)

图1。我们用于多标签分类的ADDS框架概述。带有提示的文本标签被输入到文本塔中以获得文本嵌入。图像首先由金字塔转发模块处理,然后输入到图像塔中获得视觉嵌入,视觉嵌入与文本嵌入对齐并堆叠在令牌大小维度上。然后通过文本嵌入的初始查询和视觉嵌入的初始键/值,通过六层DM解码器将文本嵌入(经过选择性语言监督模块)和堆叠的视觉嵌入融合在一起。在所有标签之间共享映射后,网络输出每个标签类的概率。

3.2 Dual-Modal解码器

具体的模块结构:

与ml-decoder比较后发现,最后三层是新加的。

(截图来自于ml-decoder源码:点击跳转)

目的在于解决:

  • 在堆叠超过3个解码器层后,模型性能通常会下降。

  • 从视觉嵌入来看,key和value输入总是相同的。由于交叉注意层的输出是其值输入的加权和,因此不同级别的解码器层的输出实际上处于相同(或接近)的语义级别,并且都来自相同的视觉嵌入。

3.3 金字塔转发

从图示中也可以看得出来,对下采样的图像进行不同程度的切小图,然后形成金字塔特征,再用图片特征塔提取特征并进行编码。这样做的主要目的是:

可以在高分辨率图像上部署从低分辨率图像训练的预训练模型,而无需重新训练。并且提高训练效率。

(具体过程略,感兴趣请看原论文)

3.4 选择性语言监督

四、实验结果

4.1 实验超参

pytorch: 1.10.2
ubuntu: 18.04.6
GPU: NVIDIA V100 GPUs服务器集群
数据集:NUS-WIDE、MS-COCO、ImageNet-1k和ImageNet-21k
优化器:Adam
输入图像分辨率:224 × 224和336 × 336 | 学习率:3 × 10−4
输入图像分辨率:448×448、640×640、1344×1344 | 学习率:1 × 10−4
batch size: 56
max epoch: 40
weight delay: 10-4
loss function: ASL

4.2 开放词汇多标签分类

在开放词汇多标签分类中,我们首先在表1中展示了我们在NUS-WIDE数据集上的结果。

其中f代表冻结骨干网络,uf代表解冻骨干网络。zsl代表零样本学习,ov代表开放性词汇设置。k=3的意思为topk

此外,表2显示了在MSCOCO数据集上的实验结果。我们按照以下方式进行数据分割:将类名按字母递增顺序排序后,选择前65个类为可见类,其余15个类为不可见类。

4.3 单标签到多标签分类

在单标签ImageNet-1k数据集上进行训练,并在多标签MSCOCO和NUS-WIDE数据集上进行测试。

4.4 额外的实验

常规多标签分类

除了开放词汇表的多标签分类,我们也很好奇我们的模型如何在传统的多标签分类上工作。我们在MS-COCO数据集上进行了实验,结果如表5所示。

消融性研究:dm解码器的有效性

消融研究:全金字塔转发 vs 单层金字塔转发

通过比较仅使用单层金字塔转发与在MS-COCO上使用完整金字塔转发的结果来评估金字塔转发的有效性。我们选择这个数据集是因为它具有更高的分辨率,更适合在不同分辨率上进行比较。在表7中,第二列显示了Pyramid-Forwarding中的级别索引。第一行显示了在没有金字塔转发的336 × 336图像上使用该模型的结果。第二行显示了1344×1344图像上的金字塔转发模型,但只有具有最高分辨率级别(第三级),切割成16块。第三行显示了在1344×1344分辨率图像上具有完整金字塔转发的模型。我们可以看到,使用完整的金字塔转发并不能提供更多的性能提升。

消融研究:训练类数量的影响

为了公平比较,我们在ImageNet-21k中过滤掉了与NUS-WIDE重叠的类,在ImageNet-21k中选择了前15k个没有与NUS-WIDE重叠的类。我们为每个类选择100张图像,这样所选择的数据集包含1.3M张图像,与ImageNet-1k (1.3M)的水平相同。

其他VLP模型的实验

最后,我们也很好奇其他VLP模型在我们的方法下表现如何。我们考虑两个例子:BLIP[26]和SLIP[35]。它们都具有类似CLIP的对比损失,以确保视觉和文本特征之间的对齐。在表9的第二行中,使用ViT-L图像编码器的BLIP模型在224 × 224图像上显示出35.15%的良好效果。第三行中的SLIP显示了类似的性能。然而,当我们使用在MS-COCO上进行微调的BLIP模型而没有对比度损失以确保对准时,mAP迅速下降到2.52%。这有力地表明,视觉和文本嵌入之间的相关性在提供性能提升方面起着重要作用,而不是图像或文本编码器本身。

五、结论

本文提出了一种新的开放词汇多标签分类框架——对齐双模态分类器。该框架基于文本视觉对齐,并利用了一种新颖的双峰解码器设计和金字塔转发技术。该方法在开放词汇多标签分类、单到多标签分类以及传统的多标签分类任务上都具有显著的优势。


http://www.ppmy.cn/news/1554543.html

相关文章

Django Fixtures 使用指南:JSON 格式详解

在Django开发中,fixtures是一种非常有用的工具,它们可以帮助我们序列化数据库内容,并在不同的环境或测试中重用这些数据。本文将详细介绍Django fixtures的概念、如何生成和使用JSON格式的fixtures。 什么是Fixtures? Fixtures是…

如何使用 Python 发送 HTTP 请求?

在Python中发送HTTP请求最常用的库是requests,它提供了简单易用的API来发送各种类型的HTTP请求。 除此之外,还有标准库中的http.client(以前叫做httplib)和urllib,但它们相对更底层,代码量较大&#xff0c…

vue地址解析+虚拟手机号解析

&#xff08;1&#xff09;安装 address-parse模块 npm install address-parse --save &#xff08;2&#xff09;地址修改-弹窗页面 <template><div><el-dialog title"修改收货地址" :visible.sync"dialogVisible" width"45%"…

群控系统服务端开发模式-应用开发-登录退出发送邮件

一、登录成功发送邮件 在根目录下app文件夹下controller文件夹下common文件夹下&#xff0c;修改Login.php&#xff0c;代码如下 <?php /*** 登录退出操作* User: 龙哥三年风水* Date: 2024/10/29* Time: 15:53*/ namespace app\controller\common; use app\controller\Em…

单元测试SpringBoot

添加测试专用属性 加载测试专用bean Web环境模拟测试 数据层测试回滚 测试用例数据设定

Rust快速入门(二)

三个指令&#xff1a; cargo run 执行 --release&#xff1a; 由于使用run命令rust默认为debug模式&#xff0c;代码中很多debug数据就会打印&#xff0c;于是我们使用relsase参数就可以不输出debug的代码。 cargo check 校验是否能够通过编译 cargo build 打包为可执行文件 …

【jvm】GC Roots有哪些

目录 1. 说明2. 虚拟机栈&#xff08;栈帧中的局部变量表&#xff09;中的引用3. 方法区中的类静态属性引用4. 本地方法栈&#xff08;Native方法栈&#xff09;中JNI&#xff08;Java Native Interface&#xff09;的引用5. 活跃线程&#xff08;Active Threads&#xff09;6.…

Scala的正则表达式(1)

package hfd //正则表达式的应用场景 //1.查找 findAllin //2.验证 matches //3.替换//验证用户名十分合法 //规则&#xff1a; //1.长度在6-12之间 //2.不能数字开头 //3.只能包含数字&#xff0c;大小写字母&#xff0c;下划线 object Test36 {def main(args: Array[String])…