基础论文学习(3)——SwinTransformer

news/2024/11/17 15:56:17/

目前Transformer应用到图像领域的挑战:

  • 图像分辨率高,像素点多,如果需要更多特征就必须构建很长的序列,但Transformer基于全局自注意力的计算导致计算量较大,能否用窗口+分层的形式代替长序列,实现类似CNN感受野的效果?

针对上述问题,我们提出了一种包含滑窗操作,具有层级设计的Swin Transformer,逐层合并tokens。
在这里插入图片描述

其中滑窗操作包括不重叠的local window + 重叠的cross-window将注意力计算限制在一个窗口中,一方面能引入CNN卷积操作的局部性,另一方面能节省计算量

1. SwinTransformer总体架构

整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。

  • 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块(对image进行卷积,然后对特征图切分为patch),并嵌入到Embedding,构建token序列。
  • 在每个Stage里,由Patch Merging和多个Block组成。
  • 其中Patch Merging模块主要在每个Stage一开始进行下采样(W和H不断减小,C不断增大),降低图片分辨率。
  • 而Block具体结构如右图所示,主要是LayerNormMLPWindow AttentionShifted Window Attention组成 (提供了2种attention计算方法)
    在这里插入图片描述
class SwinTransformer(nn.Module):def __init__(...):super().__init__()...# absolute position embeddingif self.ape:self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))self.pos_drop = nn.Dropout(p=drop_rate)# build layersself.layers = nn.ModuleList()for i_layer in range(self.num_layers):layer = BasicLayer(...)self.layers.append(layer)self.norm = norm_layer(self.num_features)self.avgpool = nn.AdaptiveAvgPool1d(1)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()def forward_features(self, x):
# step1:Patch Embeddingx = self.patch_embed(x)if self.ape:x = x + self.absolute_pos_embedx = self.pos_drop(x)# step2:BasicLayer = feature_shift + Window Partition + W-MSA/SW-MSA + Window Reverse + reverse_shift + Patch Mergingfor layer in self.layers:  # 遍历4个stagex = layer(x)  # step3:LN + AvgPool + flattenx = self.norm(x)  # (B, L, C)->(B, L, C)x = self.avgpool(x.transpose(1, 2))  # (B, L, C)->(B, C, 1)x = torch.flatten(x, 1)  # (B, C, 1)->(B, C)return xdef forward(self, x):x = self.forward_features(x)
# step4:FC(不同任务的Head层不同)x = self.head(x)  # # (B, C)->(B, num_class)return x

其中有几个地方处理方法与ViT不同:

  • ViT在输入会给embedding进行位置编码。而Swin-T这里则是作为一个可选项(self.ape),Swin-T是在计算Attention的时候做了一个相对位置编码
  • ViT会单独加上一个可学习参数,作为分类的token。而Swin-T则是直接做平均,输出分类,有点类似CNN最后的全局平均池化层

1.1 Patch Embedding

在输入进Block前,我们需要将图片切成一个个patch,然后嵌入向量。

具体做法是对原始图片(224,224,3)裁成一个个 patch_size * patch_size的窗口大小,然后进行嵌入。

这里可以将stride=4,kernel_size=4设置为patch_size=4大小,按照VIT中patch embedding的方式(不重叠卷积)得到每一个图像块patch对应长度为embed_dim的向量。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度。输出(3136, 96)相当于3136个长度为96的token,j将tokens序列排列为正方形即(56*56, 96)

import torch
import torch.nn as nnclass PatchEmbed(nn.Module):def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):super().__init__()img_size = to_2tuple(img_size) # -> (img_size, img_size)patch_size = to_2tuple(patch_size) # -> (patch_size, patch_size)patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]self.img_size = img_sizeself.patch_size = patch_sizeself.patches_resolution = patches_resolutionself.num_patches = patches_resolution[0] * patches_resolution[1]self.in_chans = in_chansself.embed_dim = embed_dimself.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)  # 这里!!if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = Nonedef forward(self, x):# 假设采取默认参数x = self.proj(x) # 出来的是(N, 96, 224/4, 224/4) x = torch.flatten(x, 2) # 把HW维展开,(N, 96, 56*56)x = torch.transpose(x, 1, 2)  # 把通道维放到最后 (N, 56*56, 96)if self.norm is not None:x = self.norm(x)return x

1.2 Window Partition/Reverse

window partition函数是用于对张量按非重叠窗口大学window_size划分为一条条tokens,指定窗口大小。将原本的张量从 N H W C, 划分成 num_windows*B, window_size, window_size, C,其中 num_windows = H*W / (window_size*window_size),即窗口的个数。

如输入特征图(56,56,96),默认window_size=7x7,所以分为8x8个窗口,num_windows=64,输出特征图(64, 7, 7, 96),之前的单位是token(共56x56=3136个token),现在的单位是窗口(共8x8=64个window,每个window聚集了7x7=49个token),最后把每个window内的token聚合展平为一个大token,每个大token的shape=(49,96)

window reverse函数则是对应的逆过程。这两个函数会在后面的Window Attention用到。
在这里插入图片描述
实现起来,window partition和window reverse没有可学习参数,因而不需要继承其他的类,写成函数就行。上面windows_partition是将送进来的特征进行window_size的划分,最终变为一条条tokens(对应示意图!!!)

def window_partition(x, window_size):B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windowsdef window_reverse(windows, window_size, H, W):B = int(windows.shape[0] / (H * W / window_size / window_size))x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)return x

1.3 W-MSA 和 SW-MSA

两者串联起来就是一个Swin Transformer Block:

  • W-MSA 窗口多头自注意力机制(windows multi-head self attention):窗口内部multi-head self-attention
  • SW-MSA 滑动窗口多头自注意力机制(shift windows multi-head self attention):窗口之间multi-head self-attention
    在这里插入图片描述

W-MSA

传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。

输入特征图(64, 7, 7, 96),window size=7(包含7x7个长度96的token),共64个窗口。
在这里插入图片描述

swin transformer是按照window size内的小方格计算self-attention的,比如上图中的windows size=7,也就是每7*7个tokens(红色框)之间计算多头self-attention(head=3)。

一次性用x计算出qkv三个矩阵:3个qkv矩阵放在一起的shape=(3, 64, 3, 49, 32),3个矩阵,64个window,head=3, 窗口大小=7x7=49,每个head特征长度96/3=32,64个窗口自己的attention结果是(64, 3, 49, 49)。

这里注意,计算self-attention的输入tokens的数量和维度都不变换,因此最终的输出特征图依旧是(64, 49, 96),64个窗口,每个窗口7x7个token,每个96维的token都会学习到了窗口内的自注意力。

SW-MSA

前面的Window Attention是在每个窗口下计算注意力的,为了更好的和其他window进行信息交互,Swin Transformer还引入了shifted window操作。

左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即window的个数翻倍了,由原本4个窗口变成了9个窗口。

在这里插入图片描述
在实际代码里,我们是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价
在这里插入图片描述

特征图移位+Mask操作

对特征图位移(torch.roll)之后,还是按照4个窗口计算attention,但是会有冗余计算结果,直接设置对应位置mask为负无穷(softmax后为0),忽略不需要的attetion部分(图中灰色部分),输出的结果同W-MSA 也是(56, 56, 96,不要忘记计算完对特征图还原平移)。
在这里插入图片描述

我们看下Block的前向代码:

def forward(self, x):H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = x# partition windowsx_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, Cx_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return x

整体流程如下

  • 先对特征图进行LayerNorm
  • 通过self.shift_size决定是否需要对特征图进行shift
  • 然后将特征图切成一个个窗口
  • 计算Attention,通过self.attn_mask来区分Window Attention还是Shift Window Attention
  • 将各个窗口合并回来
  • 如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复
  • 做dropout和残差连接
  • 再通过一层LayerNorm+全连接层,以及dropout和残差连接

1.4 Patch Merging

该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。
在这里插入图片描述

在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。

每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。

然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。如输入(56, 56, c),变为(28, 28, 4c),全连接输出(28, 28, 2c),这样就使得下一个stage的窗口数量减少了。

class PatchMerging(nn.Module):def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):super().__init__()self.input_resolution = input_resolutionself.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x):"""x: B, H*W, C"""H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."x = x.view(B, H, W, C)x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C)  # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return x

下面是一个示意图(输入张量B=1, H=W=8, C=1,不包含最后的全连接层调整)
在这里插入图片描述

2. 实验分析

在这里插入图片描述
在ImageNet22K数据集上,准确率能达到惊人的86.4%。另外在检测,分割等任务上表现也很优异。这篇文章创新点很棒,引入window这一个概念,将CNN的局部性引入,还能控制模型整体计算量。在Shift Window Attention部分,用一个mask和移位操作,很巧妙的实现计算等价。作者的代码也写得十分赏心悦目,推荐阅读!


http://www.ppmy.cn/news/1049980.html

相关文章

[Go版]算法通关村第十三关白银——数组实现加法和幂运算

目录 数组实现加法专题题目:数组实现整数加法思路分析:复杂度:Go代码 题目:字符串加法思路分析:复杂度:Go代码 题目:二进制加法思路分析:复杂度:Go代码 幂运算专题题目&a…

购车小记:辅助驾驶(锋兰达双擎领先版(14W落地)/锐放双擎先锋版(心里预期13W落地))

文章目录 引言I 试驾L2辅助驾驶II 优惠2.1 补贴2.2 坚持免息2.3 礼包III 车型对比3.1 锐放双擎先锋3.2 锋兰达双擎领先版引言 最近想买辆代步车,关注了锐放、锋兰达。 记录下心得。 流程:多家店对比落地价、礼包、政府补贴;合同没有确定不交意向金。 不要因为价格优惠/政府…

WebRTC音视频通话-iOS端调用ossrs直播拉流

WebRTC音视频通话-iOS端调用ossrs直播拉流 之前实现iOS端调用ossrs服务,文中提到了推流。没有写拉流流程,所以会用到文中的WebRTCClient。请详细查看:https://blog.csdn.net/gloryFlow/article/details/132257196 一、iOS播放端拉流效果 二…

量子非凡去广告接口

量子非凡去广告接口,免费发布,请各位正常调用,别恶意攻击 >>>https://videos.centos.chat/weisuan.php/?url

如何关闭“若要接收后续google chrome更新,您需使用windows10或更高版本”

Windows Registry Editor Version 5.00[HKEY_CURRENT_USER\Software\Policies\Google\Chrome] "SuppressUnsupportedOSWarning"dword:00000001 如何关闭“若要接收后续 google chrome 更新,您需使用 windows 10 或更高版本” - 知乎

latex 笔记:cs论文需要的排版格式

主要针对英文文献 1 基本环境 连字符 不同长度的"-"表示不同含义。 一个"-"长度的连字符用于词中两个"-"长度的连字符常用于制定范围三个"-"长度的连字符是破折号数学中的负数要用数学环境下的-得到 强调 在正式文章中, 通常不…

派森 #P125. 寻找反素数

描述 反素数,英文称作 emirp(prime(素数)的左右颠倒拼写),是素数的一种。‪‬‪‬‪‬‪‬‪‬‮‬‪‬‭‬‪‬‪‬‪‬‪‬‪‬‮‬‪‬‭‬‪‬‪‬‪‬‪‬‪‬‮‬‭‬‫‬‪‬‪‬‪‬‪‬‪‬‮‬‭…

uboot使用

一、uboot模式 自启动模式 uboot启动后若没有用户介入,倒计时结束后会自动执行自启动环境变量(bootcmd)中设置的命令(一般作加载和启动内核) 交互模式 倒计时结束之前按下任意按键uboot会进入交互模式,交互模式下用户可输入ub…