目前Transformer应用到图像领域的挑战:
- 图像分辨率高,像素点多,如果需要更多特征就必须构建很长的序列,但Transformer基于全局自注意力的计算导致计算量较大,能否用
窗口+分层
的形式代替长序列,实现类似CNN感受野的效果?
针对上述问题,我们提出了一种包含滑窗操作,具有层级设计的Swin Transformer,逐层合并tokens。
其中滑窗操作包括不重叠的local window + 重叠的cross-window。将注意力计算限制在一个窗口中,一方面能引入CNN卷积操作的局部性,另一方面能节省计算量
。
1. SwinTransformer总体架构
整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。
- 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块(对image进行卷积,然后对特征图切分为patch),并嵌入到Embedding,构建token序列。
- 在每个Stage里,由Patch Merging和多个Block组成。
- 其中Patch Merging模块主要在每个Stage一开始进行下采样(W和H不断减小,C不断增大),降低图片分辨率。
- 而Block具体结构如右图所示,主要是
LayerNorm
,MLP
,Window Attention
和Shifted Window Attention
组成 (提供了2种attention计算方法)
class SwinTransformer(nn.Module):def __init__(...):super().__init__()...# absolute position embeddingif self.ape:self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))self.pos_drop = nn.Dropout(p=drop_rate)# build layersself.layers = nn.ModuleList()for i_layer in range(self.num_layers):layer = BasicLayer(...)self.layers.append(layer)self.norm = norm_layer(self.num_features)self.avgpool = nn.AdaptiveAvgPool1d(1)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()def forward_features(self, x):
# step1:Patch Embeddingx = self.patch_embed(x)if self.ape:x = x + self.absolute_pos_embedx = self.pos_drop(x)# step2:BasicLayer = feature_shift + Window Partition + W-MSA/SW-MSA + Window Reverse + reverse_shift + Patch Mergingfor layer in self.layers: # 遍历4个stagex = layer(x) # step3:LN + AvgPool + flattenx = self.norm(x) # (B, L, C)->(B, L, C)x = self.avgpool(x.transpose(1, 2)) # (B, L, C)->(B, C, 1)x = torch.flatten(x, 1) # (B, C, 1)->(B, C)return xdef forward(self, x):x = self.forward_features(x)
# step4:FC(不同任务的Head层不同)x = self.head(x) # # (B, C)->(B, num_class)return x
其中有几个地方处理方法与ViT不同:
- ViT在输入会给embedding进行位置编码。而Swin-T这里则是作为一个可选项(self.ape),Swin-T是在计算Attention的时候做了一个相对位置编码
- ViT会单独加上一个可学习参数,作为分类的token。而Swin-T则是直接做平均,输出分类,有点类似CNN最后的全局平均池化层
1.1 Patch Embedding
在输入进Block前,我们需要将图片切成一个个patch,然后嵌入向量。
具体做法是对原始图片(224,224,3)裁成一个个 patch_size * patch_size
的窗口大小,然后进行嵌入。
这里可以将stride=4,kernel_size=4设置为patch_size=4大小,按照VIT中patch embedding的方式(不重叠卷积)得到每一个图像块patch对应长度为embed_dim的向量。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度。输出(3136, 96)相当于3136个长度为96的token,j将tokens序列排列为正方形即(56*56, 96)
。
import torch
import torch.nn as nnclass PatchEmbed(nn.Module):def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):super().__init__()img_size = to_2tuple(img_size) # -> (img_size, img_size)patch_size = to_2tuple(patch_size) # -> (patch_size, patch_size)patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]self.img_size = img_sizeself.patch_size = patch_sizeself.patches_resolution = patches_resolutionself.num_patches = patches_resolution[0] * patches_resolution[1]self.in_chans = in_chansself.embed_dim = embed_dimself.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # 这里!!if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = Nonedef forward(self, x):# 假设采取默认参数x = self.proj(x) # 出来的是(N, 96, 224/4, 224/4) x = torch.flatten(x, 2) # 把HW维展开,(N, 96, 56*56)x = torch.transpose(x, 1, 2) # 把通道维放到最后 (N, 56*56, 96)if self.norm is not None:x = self.norm(x)return x
1.2 Window Partition/Reverse
window partition函数是用于对张量按非重叠窗口大学window_size划分为一条条tokens,指定窗口大小。将原本的张量从 N H W C
, 划分成 num_windows*B, window_size, window_size, C
,其中 num_windows = H*W / (window_size*window_size)
,即窗口的个数。
如输入特征图(56,56,96),默认window_size=7x7,所以分为8x8个窗口,num_windows=64,输出特征图(64, 7, 7, 96),之前的单位是token(共56x56=3136个token),现在的单位是窗口(共8x8=64个window,每个window聚集了7x7=49个token),最后把每个window内的token聚合展平为一个大token,每个大token的shape=(49,96)
。
而window reverse函数则是对应的逆过程。这两个函数会在后面的Window Attention用到。
实现起来,window partition和window reverse没有可学习参数,因而不需要继承其他的类,写成函数就行。上面windows_partition是将送进来的特征进行window_size的划分,最终变为一条条tokens(对应示意图!!!)
def window_partition(x, window_size):B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windowsdef window_reverse(windows, window_size, H, W):B = int(windows.shape[0] / (H * W / window_size / window_size))x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)return x
1.3 W-MSA 和 SW-MSA
两者串联起来就是一个Swin Transformer Block:
- W-MSA 窗口多头自注意力机制(windows multi-head self attention):窗口内部multi-head self-attention
- SW-MSA 滑动窗口多头自注意力机制(shift windows multi-head self attention):窗口之间multi-head self-attention
W-MSA
传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。
输入特征图(64, 7, 7, 96),window size=7(包含7x7个长度96的token),共64个窗口。
swin transformer是按照window size内的小方格计算self-attention的,比如上图中的windows size=7,也就是每7*7个tokens(红色框)之间计算多头self-attention(head=3)。
一次性用x计算出qkv三个矩阵:3个qkv矩阵放在一起的shape=(3, 64, 3, 49, 32),3个矩阵,64个window,head=3, 窗口大小=7x7=49,每个head特征长度96/3=32
,64个窗口自己的attention结果是(64, 3, 49, 49)。
这里注意,计算self-attention的输入tokens的数量和维度都不变换,因此最终的输出特征图依旧是(64, 49, 96),64个窗口,每个窗口7x7个token,每个96维的token都会学习到了窗口内的自注意力。
SW-MSA
前面的Window Attention是在每个窗口下计算注意力的,为了更好的和其他window进行信息交互,Swin Transformer还引入了shifted window操作。
左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即window的个数翻倍了,由原本4个窗口变成了9个窗口。
在实际代码里,我们是通过对特征图移位,并给Attention设置mask
来间接实现的。能在保持原有的window个数下,最后的计算结果等价。
特征图移位+Mask操作
对特征图位移(torch.roll)之后,还是按照4个窗口计算attention,但是会有冗余计算结果,直接设置对应位置mask为负无穷(softmax后为0),忽略不需要的attetion部分(图中灰色部分),输出的结果同W-MSA 也是(56, 56, 96,不要忘记计算完对特征图还原平移)。
我们看下Block的前向代码:
def forward(self, x):H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = x# partition windowsx_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, Cx_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return x
整体流程如下
- 先对特征图进行LayerNorm
- 通过self.shift_size决定是否需要对特征图进行shift
- 然后将特征图切成一个个窗口
- 计算Attention,通过self.attn_mask来区分Window Attention还是Shift Window Attention
- 将各个窗口合并回来
- 如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复
- 做dropout和残差连接
- 再通过一层LayerNorm+全连接层,以及dropout和残差连接
1.4 Patch Merging
该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。
在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。
每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。
然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。如输入(56, 56, c),变为(28, 28, 4c),全连接输出(28, 28, 2c),这样就使得下一个stage的窗口数量减少了。
class PatchMerging(nn.Module):def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):super().__init__()self.input_resolution = input_resolutionself.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x):"""x: B, H*W, C"""H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."x = x.view(B, H, W, C)x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :] # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :] # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :] # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C) # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return x
下面是一个示意图(输入张量B=1, H=W=8, C=1,不包含最后的全连接层调整)
2. 实验分析
在ImageNet22K数据集上,准确率能达到惊人的86.4%。另外在检测,分割等任务上表现也很优异。这篇文章创新点很棒,引入window这一个概念,将CNN的局部性引入,还能控制模型整体计算量。在Shift Window Attention部分,用一个mask和移位操作,很巧妙的实现计算等价。作者的代码也写得十分赏心悦目,推荐阅读!