BLIP-2:salesforce提出基于冻结视觉编码器和LLM模型参数的高效训练多模态大模型

news/2024/10/17 22:12:48/

  • 论文链接:https://arxiv.org/abs/2301.12597

  • 项目代码:https://github.com/salesforce/LAVIS/tree/main/projects/blip2

  • 体验地址:https://huggingface.co/spaces/Salesforce/BLIP2

  • 文档介绍:https://huggingface.co/docs/transformers/main/en/model_doc/blip-2

  • 微调参考:https://github.com/salesforce/LAVIS

  • Huggingface Space地址:https://hf.co/spaces/Salesforce/BLIP2

在过去的几年中,视觉语言预训练(VLP)在模型越来越大的情况下,不断刷新SOTA,但是由于采用端到端的方式训练,在预训练期间需要大量的计算成本。自从2022年11月ChatGPT发布以来,人们越来越意识到大模型的“涌现能力”,尤为突出的就是zero-shot能力(不需要任何fine-tuning就可以直接进行预测)。由于LLM巨大的参数量,fine-tuning的成本很高,因此冻结LLM参数的研究和应用层出不穷,比如adapter策略,LoRA以及本文提到的BLIP-2。

BLIP-2介绍

本文提出的BLIP-2方法是一套新型多模态预训练模型的框架,思路是通过在冻结的预训练图像编码器和冻结的预训练大语言模型之间添加一个轻量级 查询 Transformer (Query Transformer, Q-Former) 来弥合视觉和语言模型之间的模态隔阂 (modality gap)。在整个模型中,Q-Former 是唯一的可训练模块,而图像编码器和语言模型始终保持冻结状态。

采用两阶段进行预训练。在第一个预训练阶段,我们执行视觉语言表示学习,该学习强迫Q-Former学习与文本最相关的视觉特征。在第二个预训练阶段,我们通过将Q-Former的输出连接到LLM,并训练Q-Former,以便通过LLM以自然语言解释其视觉表示,从而执行视觉到语言的生成学习。

BLIP-2的关键优点包括:

BLIP-2有效地利用了参数固定的预训练图像模型和语言模型。使用在两阶段进行预训练的Q-Former弥补了模态间差距:表示学习阶段和生成学习阶段。BLIP-2在各种视觉语言任务上实现了SOTA的性能,包括视觉问答,图像释义和图像文本检索。

在LLM强大能力(例如OPT,Flant5)的支持下,可以提示BLIP-2执行遵循自然语言指令的zero-shot图像到文本生成,从而可以实现新兴功能,例如视觉知识推理,视觉对话等。

由于使用参数固定的单模态模型和轻量级Q-Former,BLIP-2比现有的SOTA更高效。例如,BLIP-2在使用远少于Flamingo 54倍可训练参数情况下,在zero-shot VQAv2上性能超过8.7%。此外,BLIP-2是一种通用方法,可以融入更优的单模态模型,以获得更好的VLP性能。

BLIP-2模型结构

Q-Former作为可训练模块,以桥接图像编码器和LLM之间模态差距。它从图像编码器中提取出固定维度的输出特征,该特征与输入图像的分辨率无关。如图2所示, Q-Former由两个共享相同自注意力层的transformer子模块组成:(1)与参数固定的图像编码器交互的image transformer,用于视觉特征提取;(2)能够同时建模文本编码器和文本解码器的text transformer。创建了一组可学习的query嵌入作为图像transformer的输入。query通过自注意力层相互作用,并通过交叉注意力层与图像特征相互作用。query还可以通过相同的自注意力层与文本相互作用。根据预训练任务,应用不同的自注意力mask来控制query-text间交互。作者使用bert-base的预训练权重来初始化Q-Former,而交叉注意力层是随机初始化的。Q-Former总共包含188M参数。请注意,query嵌入也被视为模型参数。

在实验中,作者使用了32个query嵌入,其中每个query的维度为768(与Q-Former的隐藏维度相同)。Z表示输出的query。Z ( 32 × 768 )的大小比图像特征的大小要小得多(例如ViT-L/14有257 × 1024 )。这些结构与预训练目标一起迫使query能提取出与文本最相关的视觉信息。

BLIP-2表示学习阶段

在表示学习阶段,将Q-Former与参数固定的图像编码器相连,并使用【图像-文本】对进行预训练。目标是训练Q-Former,以便query可以学会提取出与文本最相关的视觉表示。受BLIP的启发,对共享相同输入和模型参数的三个预训练目标进行联合优化。每个目标任务采用了不同的注意力屏蔽策略来以控制query和文本间的相互(见图2)。

  Image-Text Contrastive Learning (ITC):ITC学习对齐图像表示和文本表示,以使它们的相互信息最大化。它通过将【图像-文本】的正例对与负例对进行相似度对比来实现该目标。将来自 image transformer 的query表示Z 与来自text transformer的文本表示t对齐,其中t是[CLS]字符的输出嵌入。由于Z包含多个输出嵌入(每个query一个),因此首先计算每个query输出和t之间的成对相似性,然后选择最高的一个作为图像文本的相似度。为了避免信息泄漏,采用了单模态自注意力屏蔽矩阵。与端到端方法相比,由于使用冻结参数的图像编码器,每个GPU可以训练更多的训练样本。因此,我们使用in-batch负例,而不是BLIP中的动量队列。

  Image-grounded Text Generation (ITG):ITG损失训练Q-Former以生成文本,其以输入图像为条件。由于Q-Former的结构不允许在图像编码器和文本字符之间进行直接交互,因此生成文本所需的信息必须首先由query提取,然后通过自注意力层传递给文本字符。因此,query被迫提取视觉特征,以捕获文本相关的所有信息。我们采用多模态因果自注意力屏蔽矩阵来控制query与文本间的相互作用,类似于UniLM。queries可以彼此互相看到,但不能无法看到文本字符。每个文本字符都可以看到所有query及其先前的文本字符。还将[CLS]字符替换为新的[DEC]来作为第一个文本,以发出解码任务的信号。

  Image-Text Matching (ITM):ITM旨在学习图像和文本表示之间的细粒度对齐。这是一个二分类任务,要求模型预测【图像-文本】对是正例(匹配)还是负例(不匹配)。使用双向自注意力屏蔽矩阵,这时所有query和文本都可以互相看到。query嵌入Z因此捕获了多模态信息。将每个query嵌入映射到两分类的线性分类器中以获得logits,并对所有query的logits进行平均,以输出匹配分数。

BLIP-2生成式预训练阶段

在生成式预训练阶段,将Q-Former(带有参数固定的图像编码器)连接到参数固定的LLM,以捕获LLM的生成能力。如图3所示,使用全连接层(FC)线性地将query嵌入Z映射到与LLM文本嵌入相同的维度。然后将映射后的query嵌入拼接到到输入文本嵌入前,这充当了一种软提示,LLM以从Q-former中抽取的视觉表示为条件进行后续生成。由于Q-Former已经进行了预训练以提取语言相关的视觉表示,因此它可以有效地作为信息载体,在删除无关的视觉信息的同时,将最有用的信息提供给LLM。这减轻了LLM学习视觉语言对齐的负担,从而减轻了灾难性遗忘的问题。

  作者尝试了两种类型的LLM:基于decoder的LLM和基于encoder-decoder的LLM。对于基于decoder的LLM,使用语言建模损失进行预训练,其中冻结的LLM的任务是生成以Q-Former的视觉表示为条件的文本。对于基于encoder-decoder的LLM,使用前缀语言建模损失进行预训练,并且将文本分为两个部分,前缀文本与视觉表示作为LLM编码器的输入,后缀文本用作LLM解码器的生成目标。

BLIP-2使用

Step1、安装transformer环境

pip install git+https://github.com/huggingface/transformers.git

Step2、收集测试数据

我们需要一个输入图像。《纽约客》每周都会面向其读者举办一场 卡通字幕比赛。我们从中取一张卡通图像输入给 BLIP-2 用于测试。

卡通字母比赛链接:
https://www.newyorker.com/cartoons/contest#thisweek

import requests
from PIL import Imageurl = 'https://media.newyorker.com/cartoons/63dc6847be24a6a76d90eb99/master/w_1160,c_limit/230213_a26611_838.jpg'
image = Image.open (requests.get (url, stream=True).raw).convert ('RGB')  
display (image.resize ((596, 437)))

New Yorker Cartoon

Step3、加载模型

现在我们有一张输入图像了,还需要一个预训练过的 BLIP-2 模型和相应的预处理器来处理输入。你可以在 Hugging Face Hub 上找到所有可用的预训练 checkpoints 列表。这里,我们将加载一个使用 Meta AI 的预训练 OPT 模型的 BLIP-2 checkpoint,该 OPT 模型具有 27 亿个参数。

from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torchprocessor = AutoProcessor.from_pretrained ("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained ("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)

请注意,你暂时还无法使用 Auto API (例如 AutoModelForXXX) 来加载 BLIP-2 模型,这种情况在 Hugging Face 中比较少见。你需要显式使用 Blip2ForConditionalGeneration 来加载 BLIP-2 模型。虽然自动获取模型还不能做到,但是你可以使用 AutoProcessor 来获取匹配的处理器类,在本例中为 Blip2Processor

我们可以使用 GPU 来加快文本生成速度:

device = "cuda" if torch.cuda.is_available () else "cpu"
model.to (device)

下面看一些具体的案例

图像字幕生成

我们先看看 BLIP-2 是否可以零样本地为《纽约客》卡通图像生成字幕。要为图像添加字幕,我们不必向模型提供任何文本提示,仅提供预处理过的输入图像。没有任何文字提示,模型将从 BOS (beginning-of-sequence) 开始生成图像字幕。

inputs = processor (image, return_tensors="pt")generated_ids = model.generate (**inputs, max_new_tokens=20)
generated_text = processor.batch_decode (generated_ids, skip_special_tokens=True)[0].strip ()
print (generated_text)
"two cartoon monsters sitting around a campfire"

对于未使用《纽约客》风格的卡通图像训练过的模型,这是一个令人印象深刻的准确描述!

有提示图片字幕生成

我们还可以通过提供文本提示来扩展图像字幕生成,模型将在给定图像的情况下接着提示词往下补充。

prompt = "this is a cartoon of"inputs = processor (image, text=prompt, return_tensors="pt").to (device, torch.float16)generated_ids = model.generate (**inputs, max_new_tokens=20)
generated_text = processor.batch_decode (generated_ids, skip_special_tokens=True)[0].strip ()
print (generated_text)
"two monsters sitting around a campfire"
prompt = "they look like they are"inputs = processor (image, text=prompt, return_tensors="pt").to (device, torch.float16)generated_ids = model.generate (**inputs, max_new_tokens=20)
generated_text = processor.batch_decode (generated_ids, skip_special_tokens=True)[0].strip ()
print (generated_text)
"having a good time"

视觉问答

用于视觉问答时,提示必须遵循特定格式: "Question: {} Answer:"

prompt = "Question: What is a dinosaur holding? Answer:"inputs = processor (image, text=prompt, return_tensors="pt").to (device, torch.float16)generated_ids = model.generate (**inputs, max_new_tokens=10)
generated_text = processor.batch_decode (generated_ids, skip_special_tokens=True)[0].strip ()
print (generated_text)
"A torch"

基于聊天的提示

最后,我们可以通过拼接对话中每轮的问题和回答来创建类似 ChatGPT 的体验。我们用某个提示 (比如 “恐龙拿着什么?”) 来问模型,模型会为它生成一个答案 (如 “火炬”),我们可以把这一问一答拼接到对话中。然后我们再来一轮,这样就把上下文 (context) 建立起来了。但是,需要确保的是,上下文不能超过 512 个标记,因为这是 BLIP-2 使用的语言模型 (OPT 和 T5) 的上下文长度。

context = [("What is a dinosaur holding?", "a torch"),("Where are they?", "In the woods.")
]
question = "What for?"
template = "Question: {} Answer: {}."prompt = "".join ([template.format (context [i][0], context [i][1]) for i in range (len (context))]) +" Question: "+ question +" Answer:"print (prompt)
Question: What is a dinosaur holding? Answer: a torch. Question: Where are they? Answer: In the woods.. Question: What for? Answer:

inputs = processor (image, text=prompt, return_tensors="pt").to (device, torch.float16)generated_ids = model.generate (**inputs, max_new_tokens=10)
generated_text = processor.batch_decode (generated_ids, skip_special_tokens=True)[0].strip ()
print (generated_text)
To light a fire.

参考文献:

[1] https://baijiahao.baidu.com/s?id=1759140009156263839&wfr=spider&for=pc


http://www.ppmy.cn/news/80340.html

相关文章

seatunnel 2.3.1全流程部署使用

Seatunnel 2.3.1 部署使用 1 部署1.1 下载解压1.2 下载对应的connector1.3 安装seatunnel⭐1.4 补充一些jar包 2 测试样例2.1 官方demo fake to console2.2 mysql to console2.3 hive to console2.4 mysql to hive 3 欢迎讨论 1 部署 1.1 下载解压 https://dlcdn.apache.org/…

【分立元件】MOSFET如何用于同步整流

在电力电子中我们会使用二极管做开关,当二极管导时,相当于开关闭合,当二极管截止时,相当于开关断开。但是二极管在导通时的管压降在低压电源电路中是一个损耗来源,所以一般我们首选使用的是肖特基二极管,因为肖特基二极管的管压降比较低。 如下所示为非同步BUCK电源拓朴…

Java中的反射以及使用方法

一. 简介 在程序运行阶段, 动态获取一个类中的属性或者方法, 把这种机制成为反射机制. 可以说, 没有反射就没有Java的任何框架 二. 应用 产生对象 假设有一个Student对象 public class Student {private String name;public int age;static int nationality;public Studen…

使用MinIO文件存储系统【完成视频断点续传】业务逻辑

目录 视频上传 接口一:检查该视频/媒资文件是否已经上传完成 接口二:检查视频分块是否已经在minio中已经存在 接口三:上传分块文件到minio中(已经上传的分块会在接口二进行校验) 接口四:合并上传的分块…

天猫订单之数据分析与挖掘——分类分析

天猫订单之数据分析与挖掘——分类分析 文章目录 天猫订单之数据分析与挖掘——分类分析0. 写在前面1. 分类分析1.1 决策树预测1.2 随机森林1.3 朴素贝叶斯算法0. 写在前面 Windows:Windows10Python:Python3.9本次案例项目主要是采用Pandas和Numpy对天猫订单数据集进行处理、…

linux共享内存总结

共享内存函数由shmget、shmat、shmdt、shmctl四个函数组成 头文件&#xff1a; #include <sys/ipc..h> #include<sys/shm.h> // 创建或获取一个共享内存: 成功返回共享内存ID&#xff0c;失败返回-1 int shmget (key_t key, size_t_size, int flag); // 连接共享内…

图片翻译怎么弄?如何把图片翻译成中文?

在使用社交媒体时&#xff0c;可能会遇到来自世界各地的异文化信息&#xff0c;这时我们可以借助图片翻译的方法帮助我们更好地了解这些信息&#xff0c;促进跨文化交流。那么图片翻译怎么弄呢&#xff1f;图片翻译的方法有哪些呢&#xff1f;这篇文章给你推荐三个非常好用的图…