【LLM论文日更】| 俄罗斯套娃嵌入模型

server/2024/9/23 5:37:13/
  • 论文:https://proceedings.neurips.cc/paper_files/paper/2022/file/c32319f4868da7613d78af9993100e42-Paper-Conference.pdf
  • 代码:GitHub - RAIVNLab/MRL: Code repository for the paper - "Matryoshka Representation Learning"
  • 机构:McGill University, Mila ServiceNow Research ,Facebook CIFAR AI Chair
  • 领域:embedding model
  • 发表:NeurIPS 2022

研究背景

  1. 研究问题:这篇文章要解决的问题是如何设计一种灵活的表示学习方法,使其能够适应多个下游任务,并且能够根据任务的计算资源需求进行调整。
  2. 研究难点:该问题的研究难点包括:现有固定容量的表示在学习新任务时可能过度或不足;如何在保持准确性的前提下,显著减少表示的大小和计算成本;如何扩展表示学习方法以适应不同模态(如视觉、语言)和数据规模(如网页规模)。
  3. 相关工作:该问题的研究相关工作包括大规模数据集上的通用表示学习(如ImageNet和JFT),对比学习(如Contrastive Learning),以及自然语言处理中的预训练模型(如BERT)。这些工作通常依赖于独立的低维模型、子网络优化或后处理压缩来实现表示的灵活性,但这些方法在训练/维护开销、多次前向传播、存储和内存成本等方面存在不足。

研究方法

这篇论文提出了Matryoshka Representation Learning(MRL)用于解决表示学习中的灵活性问题。具体来说,

  1. 多粒度表示:MRL通过显式优化嵌套的O(log(d))个低维向量,在高维向量中捕获多粒度信息。每个嵌入的前几个维度是一个信息丰富的低维向量,随着维度的增加,表示逐渐变得粗糙。

优化目标:MRL的目标是学习一个d维表示向量z∈Rd,使得每个嵌套维度m∈M都能独立地作为数据点x的可迁移通用表示。优化目标是使用标准经验风险最小化方法,通过单独的线性分类器来优化每个嵌套维度的多类分类损失。

其中,L是多类softmax交叉熵损失函数,cm​是相对重要性权重。
3. 高效实现:为了提高效率,MRL采用了权重绑定技术,即所有线性分类器的权重相同,从而减少内存成本。这种变体称为Efficient Matryoshka Representation Learning(MRL-E)。

实现代码为:
 

class MRL_Linear_Layer(nn.Module):def __init__(self, nesting_list: List, num_classes=1000, efficient=False, **kwargs):super(MRL_Linear_Layer, self).__init__()self.nesting_list = nesting_listself.num_classes = num_classes # Number of classes for classificationself.efficient = efficientif self.efficient:setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs))		else:	for i, num_feat in enumerate(self.nesting_list):setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))	def reset_parameters(self):if self.efficient:self.nesting_classifier_0.reset_parameters()else:for i in range(len(self.nesting_list)):getattr(self, f"nesting_classifier_{i}").reset_parameters()def forward(self, x):nesting_logits = ()for i, num_feat in enumerate(self.nesting_list):if self.efficient:if self.nesting_classifier_0.bias is None:nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), )else:nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, )else:nesting_logits +=  (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)return nesting_logits

借用一张图,很直观:

实验设计

  1. 数据集:实验使用了多个大规模数据集,包括ImageNet-1K、JFT-300M和ALIGN数据集。对于视觉任务,使用了ResNet50和ViT-B/16模型;对于视觉+语言任务,使用了ALIGN模型;对于语言任务,使用了BERT模型。
  2. 实验设置:实验中,MRL和MRL-E模型与独立训练的低维表示(FF)、降维(SVD)、子网络方法(slimmable networks)和随机选择的高容量特征进行比较。实验评估了线性分类/探测(LP)和1-最近邻(1-NN)准确性。
  3. 参数配置:实验中使用的超参数与独立训练的基线模型相同。例如,ResNet50输出2048维表示,ViT-B/16和BERT-Base输出768维嵌入。

本文将MRL/MRL-E模型与单独训练的低维表征(FF),SVD分解,子网络[2]方法进行了比较

首先是分类任务。对于在ImageNet上训练的模型,线性分类准确率基本和FF保持一致,1-NN准确率甚至在低维时高于FF。

对于大规模数据集上训练的模型也取得了很好的精度与速度间的平衡

对于适应性分类,期望的表征大小相比FF减小了14倍。

图像检索的结果也超越了baseline,最高超过了FF 3%。适应性图像检索也达到了效率和精度的权衡,16维度做粗排,2048维度做精排的准确率已经和直接使用2048维度做排序的精度还高,但计算量大幅减小。值得一提的是本文提出了一个漏斗检索方法,即使用逐渐增大的维度16-32-64-128-256-2048 对前200-100-50-25-10个样本的逐步重排,这种方法可以省去调参,应用比较方便。

不足与反思

  1. 嵌套损失权重的优化:未来的工作可以探索自适应损失平衡方法,以实现更优的准确性-效率权衡。
  2. 不同保真度的损失函数:可以考虑使用针对不同保真度的损失函数,以解决特定方面的自适应部署问题,例如高召回率的8维表示和鲁棒的2048维表示。
  3. 搜索数据结构的集成:可以在MRL上学习一个可微分的k-d树,以实现数据集和表示感知的检索。
  4. 多目标MRL的联合优化:结合端到端可学习的搜索数据结构,进行数据驱动的自适应大规模检索,适用于Web规模的搜索应用。

http://www.ppmy.cn/server/120655.html

相关文章

Spring 的循环依赖

在 Spring 中,循环依赖是指两个或多个 Bean 相互依赖,导致在创建过程中出现了依赖死锁的问题。为了解决循环依赖,Spring 引入了三级缓存机制。了解为什么需要三级缓存机制,首先要明白循环依赖是如何发生的,以及两级缓存…

vue-ts-demo

npm i -g vue/cli PS D:\kwai\vue3\project> vue create vue3-te-demo element-plus 一个 Vue 3 UI 框架 | Element Plus https://element-plus.org/zh-CN/guide/installation.html 安装: npm install element-plus --save 完整引入使用: 使用&…

Hive企业级调优[8]—— 其他优化

目录 其他优化 CBO优化 优化说明 优化案例 谓词下推 优化说明 优化案例 矢量化查询 Fetch抓取 本地模式 优化说明 优化案例 并行执行 严格模式 其他优化 CBO优化 优化说明 CBO(Cost Based Optimizer),即基于成本的优化。在Hive中&#…

C++迭代器 iterator详解

目录 什么是迭代器 迭代器的类型 迭代器的用法 三种迭代器 范围for 什么是迭代器 它提供了一种访问容器(如列表、集合等)中元素的方法,而无需暴露容器的内部表示。迭代器使得程序员能够以统一的方式遍历不同的数据结构,而无需…

数据结构与算法——Java实现 10.习题——删除有序链表重复节点

所有无能为力的事情,我都在慢慢接受 —— 24.9.22 83. 删除排序链表中的重复元素 给定一个已排序的链表的头 head , 删除所有重复的元素,使每个元素只出现一次 。返回 已排序的链表 。 示例 1: 输入:head [1,1,2] 输出…

SOCKS5、HTTP 代理IP协议有何区别?

在网络通信领域,代理服务器的选择对于数据安全和传输效率至关重要。SOCKS5代理和HTTP代理作为两种常用的代理类型,各自具有独特的特点和适用场景。本文将深入探讨SOCKS5代理与HTTP代理的区别、特性及应用场景,为用户提供选择指南。 一、SOCK…

矩阵分析第二章内积空间手稿笔记

概念 CS定理 内积长度 向量正交 正交基 构造正交基 内积的坐标表示 正交矩阵 两组正交基的过渡矩阵是正交的 特征值的模为1 正交子空间 正交补 内积空间的同构 正交变换内积不变 证明 点到点的距离 点到子空间的距离 最小二乘法 复内积空间 复内积空间cs定理和证明 三角不等式 …

增强现实系列—Map-Relative Pose Regression for Visual Re-Localization

🌟🌟 欢迎来到我的技术小筑,一个专为技术探索者打造的交流空间。在这里,我们不仅分享代码的智慧,还探讨技术的深度与广度。无论您是资深开发者还是技术新手,这里都有一片属于您的天空。让我们在知识的海洋中…