SigLIP Code Notes


siglip-so400m-patch14-384 uses the SoViT-400m architecture. SoViT is a shape-optimized vision transformer: its shape parameters (width, depth, MLP dim) were derived experimentally. For details, see Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design:

We validate these predictions by optimizing the shape of ViT for the compute-equivalent of ViT-g/14 when the latter is pretrained on 16 billion JFT-3B examples as done in [80]. The resulting model, SoViT-400m/14, is significantly smaller and faster, yet equally competitive. It has a width of 1152, depth 27, and MLP dim 4304. Fine-tuning it on ImageNet results in a 90.3% top-1 accuracy, see Figure 2. Section 5 presents various other evaluations.
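These shape parameters map onto the Hugging Face SiglipVisionConfig roughly as follows. This is a sketch for orientation only; num_attention_heads=16 is taken from the released checkpoint's config and is an assumption here, not stated in the paper excerpt.

from transformers import SiglipVisionConfig

# SoViT-400m/14 shape parameters quoted above;
# num_attention_heads=16 is assumed from the released checkpoint config
sovit_400m_vision = SiglipVisionConfig(
    hidden_size=1152,        # width
    num_hidden_layers=27,    # depth
    intermediate_size=4304,  # MLP dim
    num_attention_heads=16,  # assumption, not from the paper excerpt
    image_size=384,
    patch_size=14,
)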

1. Code

We use the siglip-so400m-patch14-384 model, downloaded via modelscope.

# Download the model
from modelscope import snapshot_download
model_dir = snapshot_download('fireicewolf/siglip-so400m-patch14-384')

from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch

model = AutoModel.from_pretrained(model_dir)
processor = AutoProcessor.from_pretrained(model_dir)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

texts = ["a photo of 2 cats", "a photo of 2 dogs"]
# important: we pass `padding=max_length` since the model was trained with this
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image)  # these are the probabilities
print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
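Note that, unlike CLIP, SigLIP scores each (image, text) pair with an element-wise sigmoid rather than a softmax over the candidate texts, so the probabilities are independent and need not sum to 1. Continuing the snippet above:

# sigmoid: each (image, text) pair is scored independently
sigmoid_probs = torch.sigmoid(logits_per_image)
# softmax (CLIP-style): texts compete for probability mass
softmax_probs = logits_per_image.softmax(dim=-1)
print(sigmoid_probs.sum(dim=-1))  # generally != 1
print(softmax_probs.sum(dim=-1))  # always 1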

2. Preprocessing

The downloaded example image has size (640, 480). The image and text must be preprocessed first; the preprocessing parameters are listed below.
After processing, the image has shape torch.Size([1, 3, 384, 384]).
After tokenization, the text has shape torch.Size([2, 64]).
inputs holds these two values (['input_ids', 'pixel_values']).

{
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [0.5, 0.5, 0.5],
  "image_processor_type": "SiglipImageProcessor",
  "image_std": [0.5, 0.5, 0.5],
  "processor_class": "SiglipProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {"height": 384, "width": 384}
}
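To see what these parameters mean, the image branch can be reproduced by hand: resize to 384x384 with bicubic resampling (resample=3 is PIL's BICUBIC filter), rescale by 1/255, then normalize each channel with mean and std 0.5. A minimal sketch under those assumptions, not the library code itself:

import numpy as np
import torch
from PIL import Image

def preprocess_like_siglip(image: Image.Image) -> torch.Tensor:
    # do_resize: size 384x384, resample=3 == PIL.Image.BICUBIC
    image = image.convert("RGB").resize((384, 384), resample=Image.BICUBIC)
    # do_rescale: rescale_factor = 1/255 ~= 0.00392156862745098
    arr = np.asarray(image).astype(np.float32) / 255.0
    # do_normalize: image_mean = image_std = [0.5, 0.5, 0.5]
    arr = (arr - 0.5) / 0.5
    # HWC -> CHW, add batch dim -> (1, 3, 384, 384)
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)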

3. Model prediction

# Image branch: the output includes last_hidden_state and pooler_output,
# with shapes [1, 729, 1152] and [1, 1152] respectively
vision_outputs = self.vision_model(
    pixel_values=pixel_values,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    interpolate_pos_encoding=interpolate_pos_encoding,
)
# Text branch: the output includes last_hidden_state and pooler_output,
# with shapes [2, 64, 1152] and [2, 1152] respectively
text_outputs = self.text_model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
# Both branches take the final pooler_output
image_embeds = vision_outputs[1]
text_embeds = text_outputs[1]
# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# logits_per_text is tensor([[-0.3843], [-9.7228]])
# logits_per_image is tensor([[-0.3843, -9.7228]])
# cosine similarity as logits
logits_per_text = (
    torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp()
    + self.logit_bias
)
logits_per_image = logits_per_text.t()

loss = None
if return_loss:
    # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
    # Compute the sigmoid loss
    eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
    # tensor([[1., 0.],
    #         [0., 1.]])
    m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
    # tensor([[ 1., -1.],
    #         [-1.,  1.]])
    loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
    # tensor([[-9.0365e-01, -5.1934e-01],
    #         [-5.9898e-05, -9.7229e+00]])
    nll = -torch.sum(loglik, dim=-1)
    # tensor([1.4230, 9.7230])
    loss = nll.mean()
    # tensor(5.5730)

if not return_dict:
    output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
    return ((loss,) + output) if loss is not None else output

return SiglipOutput(
    loss=loss,
    logits_per_image=logits_per_image,
    logits_per_text=logits_per_text,
    text_embeds=text_embeds,
    image_embeds=image_embeds,
    text_model_output=text_outputs,
    vision_model_output=vision_outputs,
)
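The loss can be pulled out into a standalone function, which makes the pairwise structure explicit: diagonal entries are matching pairs (label +1) and everything off-diagonal is a negative (label -1). A self-contained sketch, assuming a square logits matrix with matching pairs on the diagonal (the training setup):

import torch
import torch.nn.functional as F

def siglip_sigmoid_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # logits_per_text: (n, n), matching image/text pairs on the diagonal
    n = logits_per_text.size(0)
    eye = torch.eye(n, device=logits_per_text.device)
    signs = 2 * eye - 1                  # +1 on the diagonal, -1 elsewhere
    loglik = F.logsigmoid(signs * logits_per_text)
    return -loglik.sum(dim=-1).mean()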

4. Vision model

The input is (1, 3, 384, 384). A Conv2d turns the image into a patch sequence: 384/14 = 27.43, and since there is no padding this truncates to 27, so the sequence has 27 * 27 = 729 patches. After embedding, the shape is (1, 729, 1152).
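A quick sanity check of this arithmetic with a bare Conv2d (illustrative only, random weights):

import torch
import torch.nn as nn

patchify = nn.Conv2d(3, 1152, kernel_size=14, stride=14)  # padding defaults to 0 ("valid")
x = torch.randn(1, 3, 384, 384)
patches = patchify(x)                     # (1, 1152, 27, 27): floor((384 - 14) / 14) + 1 = 27
seq = patches.flatten(2).transpose(1, 2)  # (1, 729, 1152)
print(seq.shape)                          # torch.Size([1, 729, 1152])

The forward pass in the library then looks like this: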

hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
# SiglipVisionEmbeddings(
#   (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
#   (position_embedding): Embedding(729, 1152)
# )
encoder_outputs = self.encoder(
    inputs_embeds=hidden_states,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooler_output = self.head(last_hidden_state) if self.use_head else None
if not return_dict:
    return (last_hidden_state, pooler_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
    last_hidden_state=last_hidden_state,
    pooler_output=pooler_output,
    hidden_states=encoder_outputs.hidden_states,
    attentions=encoder_outputs.attentions,
)
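The head here is an attention-pooling head: a learned probe token cross-attends over all 729 patch tokens to produce the single pooled vector. A simplified sketch of the mechanism (the actual SiglipMultiheadAttentionPoolingHead also applies a LayerNorm and a residual MLP):

import torch
import torch.nn as nn

class AttentionPoolSketch(nn.Module):
    def __init__(self, width: int = 1152, heads: int = 16):
        super().__init__()
        self.probe = nn.Parameter(torch.randn(1, 1, width))
        self.attn = nn.MultiheadAttention(width, heads, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, 729, width); the probe attends over all patches
        probe = self.probe.expand(tokens.size(0), -1, -1)
        pooled, _ = self.attn(probe, tokens, tokens)
        return pooled[:, 0]  # (batch, width)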

5. Text model

input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
# expand attention_mask
if attention_mask is not None and not self._use_flash_attention_2:
    # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
    attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

encoder_outputs = self.encoder(
    inputs_embeds=hidden_states,
    attention_mask=attention_mask,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)

# Assuming "sticky" EOS tokenization, last token is always EOS.
pooled_output = last_hidden_state[:, -1, :]
pooled_output = self.head(pooled_output)

if not return_dict:
    return (last_hidden_state, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPooling(
    last_hidden_state=last_hidden_state,
    pooler_output=pooled_output,
    hidden_states=encoder_outputs.hidden_states,
    attentions=encoder_outputs.attentions,
)
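_prepare_4d_attention_mask expands the padding mask into the shape attention expects. A minimal sketch of the idea, not the library's exact implementation: positions with mask 1 stay attendable, positions with mask 0 receive a large negative bias so softmax assigns them roughly zero weight.

import torch

def expand_attention_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask: (batch, seq_len) with 1 = real token, 0 = padding
    bsz, src_len = mask.shape
    # (batch, seq_len) -> (batch, 1, tgt_seq_len, src_seq_len)
    expanded = mask[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    # invert: masked positions get dtype-min so softmax pushes them to ~0
    return (1.0 - expanded) * torch.finfo(dtype).min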

The encoders of the vision model and the text model have the same structure, but they do not share parameters.

ModuleList(
  (0-26): 27 x SiglipEncoderLayer(
    (self_attn): SiglipSdpaAttention(
      (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
      (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
      (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
      (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
    )
    (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
    (mlp): SiglipMLP(
      (activation_fn): PytorchGELUTanh()
      (fc1): Linear(in_features=1152, out_features=4304, bias=True)
      (fc2): Linear(in_features=4304, out_features=1152, bias=True)
    )
    (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
  )
)
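As a rough check on the "400m" in the model name, the encoder parameters implied by these shapes can be tallied by hand (my own back-of-the-envelope arithmetic, ignoring biases, layer norms, embeddings, and the pooling heads):

width, depth, mlp_dim = 1152, 27, 4304
attn = 4 * width * width   # q/k/v/out projections per layer
mlp = 2 * width * mlp_dim  # fc1 + fc2 per layer
total = depth * (attn + mlp)
print(f"{total / 1e6:.0f}M")  # ~411M for one 27-layer encoder tower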

