ViT-vision transformer

news/2024/11/8 22:48:17/

ViT-vision transformer

介绍

Transformer最早是在NLP领域提出的,受此启发,Google将其用于图像,并对分类流程作尽量少的修改。

起源:从机器翻译的角度来看,一个句子想要翻译好,必须考虑上下文的信息!

如:The animal didn’t cross the street because it was too tired将其翻译成中文,这里面就涉及了it这个词的翻译,具体it是指代animal还是street就需要根据上下文来确定,所以现在问题就变成,如何让机器学习上下文?

例如有两个特征,分别为性别和收入,二者做交互特征(简单的说即两个特征相乘),可以得到如:此数据为男人的状态下收入为多少的特征,则可以利用这个特征去分析性别对收入的影响,相对于同时考虑了性别和收入的关系。那么借鉴这个思想,相对于引入一个相乘的交互关系就可以去表示上下文信息了。而Attention在本质上用一句话概括就是:带权重的相乘求和。

在Attention中,假如我们要翻译it这个词,这时候it这个词称为query(Q)待查询。查询什么呢,查询句子中的其他单词包括自己(这里其他的单词包括自己称为(keys(K)),这里的查询操作相对于上文说的相乘,而在Attention中用的是点乘操作。如果还记得Attention的输入是Patch embedding的结果,即是一个个N维空间的向量,即Q和K代表的内容都为N维空间的向量,那么点乘即可以表示这两个向量的相似程度——Q*K = |Q||K|cosθ

Q和K相乘后可以得到一个代表词和词之间相似度的概念,这里记为S。如果我们对这个S取softmax,是不是相对于就得到了当前要查询的Q,到底对应哪个词的概率比较大的概率,这里记为P。

而Attention就是对P做权重加和的结果,而为什么还要对P做权重(这个权重也是可学习的)加和呢,其实我觉得这才是Attention的精髓,因为每个权重即代表了网络对于哪个概率对应下的内容更加注意,对于哪些内容不需要注意,使网络可以更加关注与需要注意的东西,其他无关的东西,通过这个权重,相对于舍弃了。而我们记这个权重为V。

思路:ViT算法中,首先将整幅图像拆分成若干个patch,然后把这些patch的线性嵌入序列作为Transformer的输入送入网络,然后使用监督学习的方式进行图像分类的训练。

具体流程

  1. 将图像拆分成若干个patch
  2. 将patches通过一个线性映射层,得到若干个token embedding
  3. 将多个token embedding concat一个cls_token(可学习参数)
  4. 每个参数均加上position embedding位置编码,防止无法找到原来的位置
  5. 将token embedding、cls_token和position embedding一同传入encoder模块
  6. encoder模块(L个block)
    1. Layer Norm:标准归一化(便于收敛)
    2. MSA/MHA:多头子注意力机制
    3. 输入输出作残差链接
    4. Layer Norm:标准归一化(便于收敛)
    5. MLP:全连接层(Linear+…)
  7. encoder的输出通过MLP Head作分类任务

优点:模型简单且效果好,较好的扩展性,模型越大效果越好。

与CNN结构对比

  • Transformer的平移不变性和局部感知性较差,在数据量不充分时,效果较差
  • 但是对于大量的训练数据,Transformer的效果更佳
  • 无需像CNN构造复杂的网络结构,CNN往往是不断加深网络,才能对刷新某任务的SOTA

在这里插入图片描述

模型结构

图像分块嵌入Patch Embedding

具体步骤:

  1. H ∗ W ∗ C H * W * C HWC的图像,变成一个 N ∗ ( P 2 ∗ C ) N * (P^2*C) N(P2C)的序列,这个序列由一系列展平的图像块构成,即把图像切分成小块后再展平,其中, N = H W / P 2 N=HW/P^2 N=HW/P2个图像块,每个图像块的维度为 P 2 ∗ C P^2*C P2C P P P表示图像块大小, C C C表示通道数量。

  2. 将每个图像块的维度由 P 2 ∗ C P^2*C P2C变换为 D D D,在此进行embedding,只需对每个 P 2 ∗ C P^2*C P2C图像块做一个线性变换,将维度压缩至 D D D

  3. ( N + 1 ) ∗ D (N+1)*D (N+1)D的序列作为encoder的输入。

    为啥是N+1呢?因为要多加上一个维度才能关联到全局的信息,这个恰好是class token

class PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):super().__init__()img_size = (img_size, img_size)patch_size = (patch_size, patch_size)self.img_size = img_sizeself.patch_size = patch_sizeself.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])self.num_patches = self.grid_size[0] * self.grid_size[1]self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()def forward(self, x):B, C, H, W = x.shapeassert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."# flatten: [B, C, H, W] -> [B, C, HW]# transpose: [B, C, HW] -> [B, HW, C]x = self.proj(x).flatten(2).transpose(1, 2)x = self.norm(x)return x

多头自注意力机制Multi-head Self-attention

多头较于单头的优势是增强了网络的稳定性和鲁棒性

( N + 1 ) ∗ D (N+1)*D (N+1)D的序列输入至encoder进行特征提取,其最重要的结构是多头自注意力机制,2 head的multi-head attention结构如下所示,具体步骤如下:

  1. 输入 a i a^i ai经过转移矩阵 W W W,得到 q i , k i , v i q^i,k^i,v^i qi,ki,vi,再分别切分成 q i , 1 , q i , 2 , k i , 1 , k i , 2 , v i , 1 , v i , 2 , q i , 1 . . . q^{i,1},q^{i,2},k^{i,1},k^{i,2},v^{i,1},v^{i,2},q^{i,1}... qi,1,qi,2,ki,1,ki,2,vi,1,vi,2,qi,1...
  2. 接着 q i , j 与 k i , j q^{i,j}与k^{i,j} qi,jki,j做attention,得到权重向量 α α α,将 α α α v i , j v^{i,j} vi,j进行加权求和,最终得到 b i , j b^{i,j} bi,j
  3. b i , j b^{i,j} bi,j拼接起来,通过一个线性层进行处理,得到最终的结果。

具体说说其中的attention, q i , j , k i , j 与 v i , j q^{i,j},k^{i,j}与v^{i,j} qi,j,ki,jvi,j计算 b i , j b^{i,j} bi,j的方法是缩放点积注意力 (Scaled Dot-Product Attention),加权内积得到 α α α
α 1 , i = q 1 ∗ k i d α_{1,i}=\frac{q^1*k^i}{\sqrt{d}} α1,i=d q1ki
其中,d是q和k的维度大小,除以一个 d \sqrt{d} d 可以达到归一化的效果。

接着,将 α 1 , i α_{1,i} α1,i取softmax操作,并与 v i , j v^{i,j} vi,j相乘得到最后结果。

在这里插入图片描述
在这里插入图片描述

class Attention(nn.Module):def __init__(self,dim,   # 输入token的dimnum_heads=8,qkv_bias=False,qk_scale=None,attn_drop_ratio=0.,proj_drop_ratio=0.):super(Attention, self).__init__()self.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop_ratio)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop_ratio)def forward(self, x):# [batch_size, num_patches + 1, total_embed_dim]B, N, C = x.shape# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]# reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]# permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# [batch_size, num_heads, num_patches + 1, embed_dim_per_head]q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)# transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]# @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)# @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]# transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]# reshape: -> [batch_size, num_patches + 1, total_embed_dim]x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return x

多层感知机Multilayer Perceptron

class Mlp(nn.Module):"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return x

DropPath

一种特殊的 Dropout,用来替代传统的Dropout结构。作用是:若x为输入的张量,其通道为[B,C,H,W],那么drop_path的含义为在一个Batch_size中,随机有drop_prob的样本,不经过主干,而直接由分支进行恒等映射。

def drop_path(x, drop_prob: float = 0., training: bool = False):if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_probshape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNetsrandom_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_()  # binarizeoutput = x.div(keep_prob) * random_tensorreturn outputclass DropPath(nn.Module):"""Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path(x, self.drop_prob, self.training)

Class Token

假设我们将原始图像切分成 3 × 3 = 9个小图像块,最终的输入序列长度却是10,也就是说我们这里人为的增加了一个向量进行输入,我们通常将人为增加的这个向量称为 Class Token。

若没有这个向量,也就是将 N = 9 个向量输入 Transformer 结构中进行编码,我们最终会得到9个编码向量,可对于图像分类任务而言,我们应该选择哪个输出向量进行后续分类呢?两个方案可以实现:

  1. ViT算法提出了一个可学习的嵌入向量 Class Token,将它与9个向量一起输入到 Transformer 结构中,输出10个编码向量,然后用这个 Class Token 进行分类预测即可。
  2. 取除了cls_token之外的所有token的均值作为类别特征表示,即编码中的x[:,self.num_tokens:].mean(dim=1)

Positional Encoding

在self-attention中,输入是一整排的tokens,我们很容易知道tokens的位置信息,但是模型是无法分辨的,因为self-attention的运算是无向的,因此才使用positional encoding把位置信息告诉模型。

按照 Transformer 结构中的位置编码习惯,这个工作也使用了位置编码。不同的是,ViT 中的位置编码没有采用原版 Transformer 中的 sin/cos 编码,而是直接设置为可学习的 Positional Encoding。

MLP Head

得到输出后,ViT中使用了 MLP Head对输出进行分类处理,这里的 MLP Head 由 LayerNorm 和两层全连接层组成,并且采用了 GELU 激活函数。

参考链接:

  1. https://blog.csdn.net/qq_42735631/article/details/126709656?ops_request_misc=&request_id=&biz_id=102&utm_term=vision%20transformer%E6%A8%A1%E5%9E%8B&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-0-126709656.nonecase&spm=1018.2226.3001.4187
  2. https://blog.csdn.net/aixiaomi123/article/details/128025584?ops_request_misc=&request_id=&biz_id=102&utm_term=vision%20transformer%E6%A8%A1%E5%9E%8B&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-1-128025584.nonecase&spm=1018.2226.3001.4187
  3. https://github.com/google-research/vision_transformer/tree/main
  4. https://blog.csdn.net/lzzzzzzm/article/details/122963640?ops_request_misc=&request_id=&biz_id=102&utm_term=vit%20transformer%E4%B8%AD%E7%9A%84%E5%A4%9A%E5%A4%B4%E8%87%AA%E6%B3%A8%E6%84%8F%E5%8A%9B%E6%9C%BA%E5%88%B6&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-1-122963640.nonecase&spm=1018.2226.3001.4187
    02&utm_term=vit%20transformer%E4%B8%AD%E7%9A%84%E5%A4%9A%E5%A4%B4%E8%87%AA%E6%B3%A8%E6%84%8F%E5%8A%9B%E6%9C%BA%E5%88%B6&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-1-122963640.nonecase&spm=1018.2226.3001.4187
  5. https://blog.csdn.net/weixin_41803874/article/details/125729668

http://www.ppmy.cn/news/988258.html

相关文章

RedHat7.9安装mysql8.0.32 ↝ 二进制方式

RedHat7.9安装mysql8.0.32 ↝ 二进制方式 一、rpm方式安装1、检查是否安装了mariadb2、下载mysqlmysql8.0.323、上传解压4、创建安装目录,拷贝解压后的文件至安装目录/usr/local/mysql8.0/5、创建相关目录,开始安装6、创建mysql组和用户7、更改安装目录归…

Reinforcement Learning with Code 【Chapter 9. Policy Gradient Methods】

Reinforcement Learning with Code This note records how the author begin to learn RL. Both theoretical understanding and code practice are presented. Many material are referenced such as ZhaoShiyu’s Mathematical Foundation of Reinforcement Learning, . 文章…

论文浅尝 | 预训练Transformer用于跨领域知识图谱补全

笔记整理:汪俊杰,浙江大学硕士,研究方向为知识图谱 链接:https://arxiv.org/pdf/2303.15682.pdf 动机 传统的直推式(tranductive)或者归纳式(inductive)的知识图谱补全(KGC)模型都关注于域内(in-domain)数据,而比较少关…

uni-app优雅的实现时间戳转换日期格式

现在显示的格式如下图: 我期望统一格式,所以不妨前端处理一下,核心代码如下 filters: {// 时间戳处理formatDate: function(value, spe /) {value value * 1000let data new Date(value);let year data.getFullYear();let month data.…

[UE4][C++]调整分屏模式下(本地多玩家)视口的显示位置和区域

一、分屏模式设置 在UE4中,多个玩家共用一个显示器就可以启用分屏模式,按玩家人数(最大四人)将屏幕均匀分割,显示不同玩家的视角,开发者可以在编辑器里设置分割类型(水平或者垂直)&a…

小程序 获取用户头像、昵称、手机号的组件封装(最新版)

在父组件引入该组件 <!-- 授权信息 --><auth-mes showModal"{{showModal}}" idautnMes bind:onConfirm"onConfirm"></auth-mes> 子组件详细代码为: authMes.wxml <!-- components/authMes/authMes.wxml --> <van-popup show…

web流程自动化详解

今天给大家带来Selenium的相关解释操作 一、Selenium Selenium是一个用于自动化Web浏览器操作的开源工具和框架。它提供了一组API&#xff08;应用程序接口&#xff09;&#xff0c;可以让开发人员使用多种编程语言&#xff08;如Java、Python、C#等&#xff09;编写测试脚本&…

获评最高级别权威认证!融云通过中国信通院「办公即时通信软件安全能力」评测

点击报名 8 月 3 日&#xff08;周四&#xff09;融云直播课~ 近期&#xff0c;融云再获权威认可&#xff0c;旗下百幄智能在线办公套件平台正式通过中国信通院“办公即时通信软件安全能力”测评&#xff0c;并获得最高级别“卓越级”证书。关注【融云 RongCloud】&#xff0c;…