Swin transformer 论文阅读记录 代码分析

ops/2024/12/23 11:17:23/

该篇文章,是我解析 Swin transformer 论文原理(结合pytorch版本代码)所记,图片来源于源paper或其他相应博客。

代码也非原始代码,而是从代码里摘出来的片段,配上简单数据,以便理解。

当然,也可能因为设置数据不当,造成误解,请多指教。

刚写了一部分。先发布。希望多多指正。


在这里插入图片描述
Figure 1.
(a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers ,
and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red).
It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks.
(b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self attention globally.

模型结构图

在这里插入图片描述
Figure 3.
(a) The architecture of a Swin Transformer (Swin-T);
(b) two successive Swin Transformer Blocks (notation presented with Eq. (3)).
W-MSA and SW-MSA are multi-head self attention modules with regular and shifted windowing configurations, respectively.

Stage 1 – Patch Embedding

It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT.

Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.

In our implementation, we use a patch size of 4×4 and thus the feature dimension of each patch is 4×4×3 = 48.(channel–3)

A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C).
这个表述,linear embedding layer,我感觉不太准确,但是,后半部分比较准确,哈哈,将channel–3变成了96.

Several Transformer blocks with modified self-attention computation (Swin Transformer blocks) are applied on these patch tokens.

The Transformer blocks maintain the number of tokens (H/4 × W/4), and together with the linear embedding are referred to as “Stage 1”.

代码

以下代码来自于model.py:

class PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""
"""
@ time : 2024/12/17
"""
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as Fclass PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):super().__init__()patch_size = (patch_size, patch_size)self.patch_size = patch_sizeself.in_chans = in_cself.embed_dim = embed_dimself.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shape# padding# 如果输入图片的H,W不是patch_size的整数倍,需要进行paddingpad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)if pad_input:# to pad the last 3 dimensions,# (W_left,W_right, H_top,H_bottom, C_front,C_back)x = F.pad(x,(0, self.patch_size[1] - W % self.patch_size[1],0, self.patch_size[0] - H % self.patch_size[0],0, 0))# 下采样patch_size倍x = self.proj(x)_, _, H, W = x.shape# flatten: [B, C, H, W] -> [B, C, HW]# transpose: [B, C, HW] -> [B, HW, C]x = x.flatten(2).transpose(1, 2)x = self.norm(x)print(x.shape)# torch.Size([1, 3136, 96])# 224/4 * 224/4 = 3136return x, H, Wif __name__ == '__main__':img_path = "tulips.jpg"img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]print(img.size)# (500,375)#img_size = 224data_transform = transforms.Compose([transforms.Resize(int(img_size * 1.14)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])img = data_transform(img)print(img.shape)# torch.Size([3, 224, 224])# expand batch dimensionimg = torch.unsqueeze(img, dim=0)print(img.shape)# torch.Size([1, 3, 224, 224])# split image into non-overlapping patchespatch_embed = PatchEmbed(norm_layer=nn.LayerNorm)patch_embed(img)

Stage 2 – 3.2. Shifted Window based Self-Attention

Shifted window partitioning in successive blocks

The window-based self-attention module lacks connections across windows, which limits its modeling power.

To introduce cross-window connections while maintaining the efficient computation of non-overlapping windows,
we propose a shifted window partitioning approach which alternates between two partitioning configurations in consecutive Swin Transformer blocks.
为了在保持非重叠窗口高效计算的同时引入跨窗口连接,我们提出了一种移位窗口划分方法,该方法在连续的Swin Transformer块中交替使用两种不同的划分配置。

在这里插入图片描述
Figure 2.
In layer l (left), a regular window partitioning scheme is adopted, and self-attention is computed within each window.
In the next layer l + 1 (right), the window partitioning is shifted, resulting in new windows.
The self-attention computation in the new windows crosses the boundaries of the previous windows in layer l, providing connections among them.
在新窗口中进行的自注意力计算跨越了第l层中先前窗口的边界,从而在它们之间建立了连接。

Efficient batch computation for shifted configuration

An issue with shifted window partitioning is that it will result in more windows, and some of the windows will be smaller than M×M.

Here, we propose a more efficient batch computation approach by cyclic-shifting toward the top-left direction(向左上方向循环移动), as illustrated in Figure 4.

这里的 more efficient,是说相对于直观方法 padding—mask来说:

A naive solution is to pad the smaller windows to a size of M×M and mask out the padded values when computing attention.


在这里插入图片描述
Figure 4. Illustration of an efficient batch computation approach for self-attention in shifted window partitioning.


After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window.
在此转换之后,批处理窗口可能由特征图中不相邻的几个子窗口组成,因此采用掩蔽机制将自注意力计算限制在每个子窗口内。

With the cyclic-shift, the number of batched windows remains the same as that of regular window partitioning, and thus is also efficient.
通过循环移位,批处理窗口的数量与常规窗口分区的数量保持不变,因此也是高效的。


上图和叙述,并不太直观,找了相关资料,一起分析:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
移动完成之后,4是一个单独区域,5、3为一组,7、1为一组,8、6、2、0为一组。

但,5、3本身是两个图像的边缘,混在一起计算不是乱了吗?一起计算也没问题,ViT也是全局计算的。

但,Swin-Transformer为了防止这个问题,在代码中使用了masked MSA,这样就能够通过设置蒙板来隔绝不同区域的信息了。

源码中具体的方法就是将不计算的位置元素减去100。

这里需要注意的是,在窗口数据进行滑动完之后,需要将数据还原回去,即挪回到原来的位置上。

代码

以下代码来自于model.py:

def window_partition(x, window_size: int):"""将feature map按照window_size划分成一个个没有重叠的window主要思路是将feature转成 (num_windows*B, window_size*window_size, C)的shape,把需要self-attn计算的window排列到第0维,一次并行的qkv就可以了Args:x: (B, H, W, C)window_size (int): window size(M)Returns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shape# B,224,224,C# B,56,56,Cx = x.view(B, H // window_size, window_size, W // window_size, window_size, C)# B,32,7,32,7,C# B,8,7,8,7,C# permute:# [B, H//Mh, Mh,    W//Mw, Mw, C] -># [B, H//Mh, W//Mh, Mw,    Mw, C]# B,32,32,7,7,C# B,8,8,7,7,C# view:# [B, H//Mh, W//Mw, Mh, Mw, C] -># [B*num_windows,   Mh, Mw, C]# B*1024,7,7,C# B*64,7,7,C# 32*32 = 1024# 224 / 7 = 32windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windows

分析:将 [B, C, 56, 56] 最后变成了[64B, C, 7, 7],原先的 B*C 张 56*56 的特征图,最后变成了 B*64*C张7*7的特征;

即,我们有64B个样本,每个样本包含C个7x7的通道。

注意,window_size–M–7,是每个window的大小,7*7,不是7*7个window,我刚开始混淆了这一点。


class BasicLayer(nn.Module):# A basic Swin Transformer layer for one stage.def __init__(self, dim, depth, num_heads, window_size,mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):super().__init__()self.dim = dimself.depth = depthself.window_size = window_sizeself.use_checkpoint = use_checkpointself.shift_size = window_size // 2# 7//2 = 3# build blocksself.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim,num_heads=num_heads,window_size=window_size,shift_size=0 if (i % 2 == 0) else self.shift_size,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,drop=drop,attn_drop=attn_drop,drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer)for i in range(depth)])...# depth: 2, 2, 6, 2# 即,第一层,depth=2, 有两个SwinTransformerBlock,shift_size分别为:0,3# 即,第二层,depth=2, 有两个SwinTransformerBlock,shift_size分别为:0,3# 即,第三层,depth=6, 有两个SwinTransformerBlock,shift_size分别为:#	0,3,0,3,0,3# 即,第四层,depth=2, 有两个SwinTransformerBlock,shift_size分别为:0,3def create_mask(self, x, H, W):# calculate attention mask for SW-MSA
import numpy as np
import torchH = 7
W = 7
window_size = 7
shift_size = 3Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size# 拥有和feature map一样的通道排列顺序,方便后续window_partition
img_mask = torch.zeros((1, Hp, Wp, 1))
# [1, Hp, Wp, 1]
print(img_mask, '\n')h_slices = (slice(0, -window_size),slice(-window_size, -shift_size),slice(-shift_size, None)
)
print(h_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))w_slices = (slice(0, -window_size),slice(-window_size, -shift_size),slice(-shift_size, None)
)
print(w_slices, '\n')
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))cnt = 0
for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1print(img_mask)

在这里插入图片描述

import torchimg_mask = torch.rand((2, 3))
print(img_mask)
'''
tensor([[0.7410, 0.6020, 0.5195],[0.9214, 0.2777, 0.8418]])
'''
attn_mask = img_mask.unsqueeze(1) - img_mask.unsqueeze(2)
print(attn_mask)
'''
tensor([[[ 0.0000, -0.1390, -0.2215],[ 0.1390,  0.0000, -0.0825],[ 0.2215,  0.0825,  0.0000]],[[ 0.0000, -0.6437, -0.0796],[ 0.6437,  0.0000,  0.5642],[ 0.0796, -0.5642,  0.0000]]])
'''print(img_mask.unsqueeze(1))
'''
tensor([[[0.7410, 0.6020, 0.5195]],[[0.9214, 0.2777, 0.8418]]])
'''
print(img_mask.unsqueeze(2))
'''
tensor([[[0.7410],[0.6020],[0.5195]],[[0.9214],[0.2777],[0.8418]]])
'''

上面那个代码,需要根据下面这个代码对应着走,shift_size–torch.roll()

class SwinTransformerBlock(nn.Module):# Swin Transformer Block....def forward(self, x, attn_mask):H, W = self.H, self.WB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# pad feature maps to multiples of window size# 把feature map给pad到window size的整数倍pad_l = pad_t = 0pad_r = (self.window_size - W % self.window_size) % self.window_sizepad_b = (self.window_size - H % self.window_size) % self.window_size# 注意F.pad的顺序,刚好是反着来的, 例如:# x.shape = (b, h, w, c)# x = F.pad(x, (1, 1, 2, 2, 3, 3))# x.shape = (b, h+6, w+4, c+2)# 源码可能有误,修改成下面的# x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))x = F.pad(x, (0, 0, pad_t, pad_b, pad_l, pad_r))_, Hp, Wp, _ = x.shape# cyclic shiftif self.shift_size > 0:# paper中,滑动的size是窗口大小的/2(向下取整)# torch.roll以H,W的维度为例子,负值往左上移动,正值往右下移动。# 溢出的值在对角方向出现。即循环移动。shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = xattn_mask = None# partition windowsx_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C]# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]...

其中,torch.roll()方法简易示例如下:

import torchx = torch.randn(1, 4, 4, 3)
print(x, '\n')shifted_x = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
print(shifted_x, '\n')

为了方便理解,我更换了维度:

import torchx = torch.randn(1, 3, 7, 7)
print(x, '\n')shifted_x = torch.roll(x, shifts=(-3, -3), dims=(2, 3))
print(shifted_x, '\n')

在这里插入图片描述

Relative position bias


Relative Position Bias通过给自注意力机制的输出加上一个与token相对位置相关的偏置项,从而增强了模型对局部和全局信息的捕捉能力。

实现方式:

1、构建相对位置索引:

  • 首先,需要确定一个 window size ,并在该窗口内计算token之间的相对位置。
  • 通过构建相对位置索引表(relative position index table),可以方便地查询任意两个token之间的相对位置。

2、可学习的偏置表:

  • 初始化一个与相对位置索引表大小相同的可学习参数表(relative position bias table),这些参数在训练过程中会被优化。
  • 根据相对位置索引,从偏置表中查询对应的偏置值,并将其加到自注意力机制的输出上。

3、计算过程:

  • 在自注意力机制的计算中,通常会将 Q、K 和 V 进行矩阵乘法运算,得到注意力得分。
  • 然后,将 Relative Position Bias 加到注意力得分上,再进行 softmax 运算,最后与 V 相乘得到最终的输出。

代码

import torchcoords_h = torch.arange(7)
coords_w = torch.arange(7)a, b = torch.meshgrid([coords_h, coords_w], indexing="ij")coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))coords_flatten = torch.flatten(coords, 1)
# [2, Mh*Mw]
# print(coords_flatten)
'''
tensor([
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6],
[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2,3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]
])
'''
# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
# [Mh*Mw, Mh*Mw, 2]
'''
tensor([[[ 0,  0],[ 0, -1],[ 0, -2],...,[-6, -4],[-6, -5],[-6, -6]],[[ 0,  1],[ 0,  0],[ 0, -1],...,[-6, -3],[-6, -4],[-6, -5]],[[ 0,  2],[ 0,  1],[ 0,  0],...,[-6, -2],[-6, -3],[-6, -4]],...,[[ 6,  4],[ 6,  3],[ 6,  2],...,[ 0,  0],[ 0, -1],[ 0, -2]],[[ 6,  5],[ 6,  4],[ 6,  3],...,[ 0,  1],[ 0,  0],[ 0, -1]],[[ 6,  6],[ 6,  5],[ 6,  4],...,[ 0,  2],[ 0,  1],[ 0,  0]]])
'''
relative_coords[:, :, 0] += 6
# shift to start from 0relative_coords[:, :, 1] += 6
relative_coords[:, :, 0] *= 13relative_position_index = relative_coords.sum(-1)
print(relative_position_index.shape)
# torch.Size([49, 49])# print(relative_position_index)
'''
tensor([[ 84,  83,  82,  ...,   2,   1,   0],[ 85,  84,  83,  ...,   3,   2,   1],[ 86,  85,  84,  ...,   4,   3,   2],...,[166, 165, 164,  ...,  84,  83,  82],[167, 166, 165,  ...,  85,  84,  83],[168, 167, 166,  ...,  86,  85,  84]])
'''

其中,

relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 

这行代码是用来计算一个二维坐标点集 coords_flatten 中所有点对之间的相对坐标(或位移)。

未对行列加乘操作之前的矩阵,relative_coords 是一个形状为 (N, N, 2) 的数组,其中 relative_coords[i, j, :] 表示从点 i 到点 j 的相对坐标(或位移)。

在这里插入图片描述

结合其他博客的分析:

如图,假设我们现在有一个window-size=2的feature map;
这里面如果用绝对位置来表示位置索引;
然后如果用相对位置表示,就会有4个情况,但分别都是以自己为(0, 0)计算其他token的相对位置。
分别把4个相对位置展开,得到4x4的矩阵,如最下的矩阵所示。
在这里插入图片描述
请注意这里说的都是位置索引,并不是最后的位置编码。因为后面我们会根据相对位置索引去取对应位置的参数。取出来的值才是相对位置编码。
源码中,作者还将二维索引给转成了一维索引。如果直接将行列相加,就变成一维了。但这样(0, 1)和(1, 0)得到的结果都是1,这样肯定不行。来看看源码的做法怎么做的:
首先,所有行、列都加上M-1,其次将所有的行索引乘上2M-1
最后行索引和列索引相加,保证了相对位置关系,也不会出现0+1 = 1+0 的现象了。
在这里插入图片描述

刚刚也说了,之前计算的是相对位置索引,并不是实际位置偏执参数。

真正使用到的数值需要从relative position bias table,这个表的长度是等于(2M-1)X(2M-1)的。在代码中它是一个可学习参数。
在这里插入图片描述

import torch
from torch import nnwindow_size = (7, 7)
num_heads = 3relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
)
print(relative_position_bias_table)
......
nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

Stage 3 – patch merging layers

To produce a hierarchical representation, the number of tokens is reduced by patch merging layers as the network gets deeper.

The first patch merging layer concatenates the features of each group of 2×2 neighboring patches, and applies a linear layer on the 4C-dimensional concatenated features.
首个补丁合并层将每组2×2相邻补丁的特征进行拼接,并在拼接后的4C维特征上应用一个线性层。

This reduces the number of tokens by a multiple of 2×2=4(2 ×downsampling of resolution), and the output dimension is set to 2C.

Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at H/8 × W/8.

同样,结合其他大神分析,图展示如下:

在这里插入图片描述

Related Work

Self-attention based backbone architectures

Instead of using sliding windows, we propose to shift windows between consecutive layers, which allows for a more efficient implementation in general hardware.

。。。。。

Cited link or paper name

  1. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.
  2. https://blog.csdn.net/weixin_42392454/article/details/141395092

http://www.ppmy.cn/ops/144293.html

相关文章

CSS 网络安全字体

适用于 HTML 和 CSS 的最佳 Web 安全字体 下面列出了适用于 HTM L和 CSS 的最佳 Web 安全字体: Arial (sans-serif)Verdana (sans-serif)Helvetica (sans-serif)Tahoma (sans-serif)Trebuchet MS (sans-serif)Times New Roman (serif)Georgia (serif)Garamond (se…

【Java 马踏棋盘算法】韩顺平笔记

骑士周游算法 算法优化意义 1.算法是程序的灵魂,为什么有些程序可以在海量数据计算时,依然保持高速计算? 2.在Unix下开发服务器程序,功能是要支持上干万人同时在线,在上线前,做内测,一切OK,可上…

如何正确计算显示器带宽需求

1. 对显示器的基本认识 一个显示器的参数主要有这些: 分辨率:显示器屏幕上像素点的总数,通常用横向像素和纵向像素的数量来表示,比如19201080(即1080p)。 刷新率:显示器每秒钟画面更新的次数&…

【JavaEE初阶】JavaScript相应的WebAPI

目录 🌲WebAPI 背景知识 🚩什么是 WebAPI 🚩什么是 API 🎍DOM 基本概念 🚩什么是 DOM 🚩DOM 树 🍀获取元素 🚩querySelector 🚩querySelectorAll &#x1f384…

C语言经典100例

文章目录 前言123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525355565859606162636465 前言 以下题目大部分来自于C语言经典100例 1 题目:有1、2、3、4个数字,能组成多少个互不相同且无重复数字的…

Spring Boot 教程之三十六:实现身份验证

如何在 Spring Boot 中实现简单的身份验证? 在本文中,我们将学习如何使用 Spring设置和配置基本身份验证。身份验证是任何类型的安全性中的主要步骤之一。Spring 提供依赖项,即Spring Security,可帮助在 API 上建立身份验证。有很…

webpack最基础的配置

以下是一个基本的webpack配置示例,它包括了入口文件、输出配置以及模式设置(开发模式或生产模式)。 const path require(path);module.exports (env, argv) > {// 根据传入的参数确定是开发模式还是生产模式const isProduction argv.m…

[创业之路-199]:《华为战略管理法-DSTE实战体系》- 3 - 价值转移理论与利润区理论

目录 一、价值转移理论 1.1. 什么是价值? 1.2. 什么价值创造 (1)、定义 (2)、影响价值创造的因素 (3)、价值创造的三个过程 (4)、价值创造的实践 (5&…