深度学习(11)---Swin Transformer详解

news/2024/11/13 10:00:35/

文章目录

  • 一、引言
  • 二、结构
  • 三、Patch Merging操作
  • 四、W-MSA详解
  • 五、SW-MSA详解


一、引言

 1. 在原论文中,首先在开头作者就分析,当前的Transformer从NLP迁移到CV上没有大放异彩主要原因集中在:
 (1) 两个领域涉及的规模不同,NLP的规模是标准固定的,而CV的规模变化范围非常大。
 (2) CV比起NLP需要更大的分辨率,而且CV中使用Transformer的计算复杂度是图像尺度的平方,这会导致计算量过于庞大。

 2. 那么本篇论文作者为了解决该问题所用的方法有:
 (1) 引入CNN中常用的层次化构建方式来构建层次化Transformer。
 (2) 提出了 Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通过此方法能够让信息在相邻的窗口中进行传递。
 这样一来通过限制在窗口内使用自注意力,带来了更高的效率;并且通过移动,使得相邻两个窗口之间有了交互,上下层之间也就有了跨窗口连接,从而变相达到了一种全局建模的效果。另外层级式的结构不仅非常灵活的去建模各个尺度的信息并且计算复杂度随着图像大小线性增长。
因为有了像卷积神经网络一样的分层结构,有了多尺度的特征,所以很容易的应用到下游任务里,例如图像分类、物体检测、物体分割等。

 3. 与之前的ViT相比,论文中给出了相应的对比图如下,从中可以看出来两者的区别:
 (1) Swin Transformer使用了类似卷积神经网络中的层次化构建方法(Hierarchical feature maps),比如特征图尺寸中有对图像下采样4倍的,8倍的以及16倍的,这样的backbone有助于在此基础上构建目标检测,实例分割等任务。而在之前的Vision Transformer中是一开始就直接下采样16倍,后面的特征图也是维持这个下采样率不变,这样对于多尺寸特征的获得会弱一些。
 (2) Swin-Transformer使用窗口多头自注意力,将特征图划成多个不相交的区域,然后在每个窗口里进行自注意力计算,只要窗口大小固定,自注意力的计算复杂度也是固定的,那么总的计算复杂度就是图像尺寸的线性倍数,而不是Vit对整个特征图进行全局自注意力计算,这样就减少了计算量,但是也隔绝了不同窗口之间的信息交流,随之作者提出后文的移动窗口自注意力计算(Shifted Windows Multi-Head Self-Attention(SW-MSA))。

在这里插入图片描述

 4. 总的来说Swin Transformer是一种改进的ViT,但是Swin Transformer该模型本身具有了划窗操作(包括不重叠的local window和重叠的cross-window),并且具有层级设计。

二、结构

 1. Swin Transformer的网络架构如下图所示:
在这里插入图片描述

 (1) 输入:首先输入还是一张图像数据,224(宽) ∗ 224(高) ∗ 3(通道)。
 (2) 处理过程:通过卷积得到多个特征图,把特征图分成每个Patch,堆叠Swin Transformer Block,与Swin Transformer Block在每次堆叠后长宽减半,特征图个数翻倍。
 (3) Block含义:最核心的部分是对Attention的计算方法做出了改进,每个Block包括了一个W-MSA和一个SW-MSA,成对组合才能串联成一个Block。W-MSA是基于窗口的注意力计算。SW-MSA是窗口滑动后重新计算注意力。

 2. 对于上图的流程解释如下:

  • 输入图片尺寸为H✖️W✖️3(假设是224✖️224✖️3),经过Patch Patition进行分块成为不重合的patch集合,其中每个patch的尺寸为4✖️4大小,在通道方向进行展平变成4✖️4✖️3=48,所以通过Patch Partition后图像尺寸由[H, W, 3](224✖️224✖️3)变成了[H/4, W/4, 48](56✖️56✖️48),patch块的数量为H/4 x W/4(56✖️56),然后通过一个Linear embedding将划分后的patch特征维度变成我们所预制好的值(Transformer能够接受的值),这里设置为超参数C。对于上图Swin-T来说,C=96,即图像shape再由[H/4, W/4, 48]变成了[H/4, W/4, C](56✖️56✖️96)。
  • 接下来如果想有多尺度的信息,那么就要构建一个层级式的Transformer,也就是说我们需要一个像卷积操作中类似于池化的操作,也就是紧接着的Patch Merging操作,shape会由[H/4, W/4, C](56✖️56✖️96)变成[H/8, W/8, C](28✖️28✖️192)具体怎么操作看下文讲解。
  • 然后通过三个阶段构建不同大小的特征图,除了阶段1中先通过一个Linear Embeding层外,剩下三个阶段都是先通过一个Patch Merging层进行下采样。

三、Patch Merging操作

 1. 下图展示Patch merging 的操作过程,顾名思义就是将邻近的小patch合并成一个大patch,这样就可以起到一个下采样特征图的效果了。

在这里插入图片描述

 2. 前面有说,在每个阶段中首先要通过一个Patch Merging层进行下采样(阶段1除外)。如上图所示,假设输入Patch Merging的是一个4✖️4大小的单通道特征图(feature map),Patch Merging会将每个2✖️2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行Concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。

四、W-MSA详解

 引入Windows Multi-head Self-Attention(W-MSA)模块是为了减少计算量。如下图所示,左侧使用的是普通的Multi-head Self-Attention(MSA)模块,对于feature map中的每个像素(或称作token,patch)在Self-Attention计算过程中需要和所有的像素去计算。但在图右侧,在使用Windows Multi-head Self-Attention(W-MSA)模块时,首先将feature map按照M✖️M(例子中的M=2)大小划分成一个个Window,然后单独对每个Window内部进行Self-Attention。
:这样计算量会大大减少。

在这里插入图片描述

五、SW-MSA详解

 1. 前面有说,采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。

 2. 如下图所示,左侧使用的是刚刚讲的W-MSA(假设是第L层),W-MSA和SW-MSA它俩一般是成对使用的,构成一个Swin-Transformer Block,那么右侧图第L+1层使用的就是SW-MSA。根据左右两幅图对比能够发现窗口(Windows)发生了偏移可以理解成窗口从左上角分别向右侧和下方各偏移了 ⌊ M / 2 ⌋ ⌊M/2⌋ M/2 个像素

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

 3. 根据上图,可以发现通过将窗口进行偏移后,由原来的4个窗口变成9个窗口了。后面又要对每个窗口内部进行MSA,这样做感觉又变麻烦了。为了解决这个麻烦,作者又提出而了Efficient batch computation for shifted configuration,一种针对移位窗口配置的高效批处理计算方法。下面是原论文给的示意图。

在这里插入图片描述

 为了更好的理解可以参考下面的图:

在这里插入图片描述

 最终的结果中,4是一个单独的窗口,5和3合成一个窗口,7和1合成一个窗口,8,6,2,0合成一个窗口,它们大小都一样(图中画的不标准),这样就是4个4✖️4大小的窗口了,所以能够保证计算量是一样的。

 但是这样又有一个问题,把不同区域汇合到一起,它们之间的元素都是从很远的地方搬运过来的,所以它们之间不应该作自注意力,不应该有太多的联系(比如一张图的上面是天空,下面是土地,现在就是把部分天空移到了土地下面,再做自注意力就不太合适了)。所以为了防止这个问题,实际计算中使用掩码操作(Masked MSA),这样就能使用蒙版来隔绝不同区域的信息了,算出自注意力之后再进行还原。


http://www.ppmy.cn/news/1518872.html

相关文章

uniapp实现区域滚动、下拉刷新、上滑滚动加载更多

背景&#xff1a; 在uniapp框架中&#xff0c;有两种实现办法。第1种&#xff0c;是首先在page.json中配置页面&#xff0c;然后使用页面的生命周期函数&#xff1b;第2种&#xff0c;使用<scroll-view>组件&#xff0c;然后配置组件的相关参数&#xff0c;包括但不限于&…

Java面试题:equals和==的区别与联系分别是什么?

1. 运算符 是一个运算符&#xff0c;其用于比较两个变量的内存地址是否相等&#xff1b;对于基本数据类型(int、char、Boolean等)&#xff0c;比较的是它们的值&#xff1b;而对于引用数据类型的话(String、Object、ArrayList等)&#xff0c;比较的是引用&#xff0c;也就是对…

Golang | Leetcode Golang题解之第377题组合总和IV

题目&#xff1a; 题解&#xff1a; func combinationSum4(nums []int, target int) int {dp : make([]int, target1)dp[0] 1for i : 1; i < target; i {for _, num : range nums {if num < i {dp[i] dp[i-num]}}}return dp[target] }

【Kubernetes知识点问答题】第一篇

目录 1.ca-certificates, gnupg, lsb-release 三个包的解释。 2.docker-ce, docker-ce-cli, containerd.io, docker-compose-plugin 作用。 3.K8s 在 1.2 之后就不再支持 docker&#xff0c;请解释对错。 4.举例说明创建容器以及以交互方式访问容器的命令&#xff1f; 1.ca-…

Durid解析SQL语句

在外面的需求中&#xff0c;有很多需要解析SQL语句的地方&#xff0c;我们采用Durid来进行解析。 Durid可以将sql进行详细的拆分成多个部分 解析where解析SQLSelectItem解析update语句解析limit解析group by 还可以动态修改sql&#xff0c;比如在原sql上增加条件修改sql运行的…

libtorch---day03[自定义导数]

参考pytorch。 背景 希望使用勒让德多项式拟合一个周期内的正弦函数。 真值&#xff1a; y s i n ( x ) , x ∈ [ − π , π ] ysin(x),x\in\left[-\pi,\pi\right] ysin(x),x∈[−π,π] torch::Tensor x torch::linspace(-M_PI, M_PI, 2000, torch::kFloat); torch::Ten…

前端配置环境

工具类配置 一、下载Git Bash 下载地址 二、下载google浏览器 下载地址 三、下载微信开发者工具 下载地址 四、下载vscode 下载地址 1、安装中文包 安装中文包 教程 2、安装插件 3、vscode中使用git 教程 4、setting.json 我自己常用的&#xff1a; {"editor.fontSiz…

分布式中间件

1.Nacos 服务注册和服务发现原理图&#xff1a; 1.服务提供方将集群信息注册到Nacos&#xff0c;并定期心跳包提供健康信息&#xff0c;宕机即剔除 2.服务消费方定期拉取订阅信息&#xff0c;获取服务实例列表 3.服务集群的负载均衡是在消费者一方进行选择 负载均衡&#xf…