文章目录
- 摘要
- Abstract
- 1 Swin Transformer
- 1.1 输入
- 1.2 Patch Partition
- 1.3 Linear Embedding
- 1.4 Patch Merging
- 1.5 Swin Transformer Block
- 1.6 代码
- 总结
摘要
本篇博客介绍了采用类似于卷积核的移动窗口进行图像特征提取的Swin Transformer网络模型,这是一种基于Transformer的计算机视觉模型,通过引入类似CNN卷积核的Shifted Windows(移动窗口),其通过将 self-attention 计算限制在不重叠的局部窗口上,由于窗口大小固定,每个窗口内计算复杂度固定,计算复杂度跟图像大小呈线性关系,从而带来更高的效率。Swin Transformer Block的核心机制,包括W-MSA(窗口多头自注意力)和SW-MSA(移动窗口多头自注意力),这些机制从全局的自注意力变为了窗口内的自注意力,使得模型在减少计算量的同时捕捉多尺度特征。最后,通过使用预训练模型对花类数据集进行微调;通过预测代码和结果展示了模型对图像分类的高准确率。
Abstract
This blog post introduces the Swin Transformer network model, which adopts shifting windows similar to convolutional kernels for image feature extraction. This computer vision model, based on the Transformer architecture, introduces Shifted Windows akin to CNN convolutional kernels. By limiting self-attention computations to non-overlapping local windows, and with a fixed window size, the computational complexity within each window is constant, and the overall computational complexity is linear with respect to the image size, thus offering greater efficiency. The core mechanisms of the Swin Transformer Block include W-MSA (Window Multi-Head Self-Attention) and SW-MSA (Shifted Window Multi-Head Self-Attention), which transition from global self-attention to window-internal self-attention, allowing the model to capture multi-scale features while reducing computational load. Finally, the post demonstrates the model’s high accuracy in image classification by fine-tuning a pre-trained model on a flower dataset; the prediction code and results showcase the model’s high accuracy in image classification tasks.
1 Swin Transformer
Swin Transformer提出了类似CNN卷积核的Shifted Windows(移动窗口),以减少patch之间做自注意力的次数,达到节省时间的目的。
接下来我将参照Swin Transformer模型整体结构图来讲解该模型,网络结构图如下所示:
1.1 输入
以输入图像大小为 224x224 为例,将图像tensor 224x224x3 从1处传入模型作为开始。
1.2 Patch Partition
到达第2步Patch Partiton,因为是以Transformer为基础,如同ViT一样需要将图片打成patch传入模型。这里将图像大小 224x224 通过卷积核大小为 4x4,步伐为4的卷积将图像转换为 56x56 的大小,通道数从3变为3x4x4=48维。
1.3 Linear Embedding
接下来数据将传入Swin Transformer最重要的第3大部分,首先在第一个stage会经过全连接层,这里的通道数在论文中是设置好的,即C=96。如上图所示,56x56x48 的tensor经过Linear Embedding维度变化为56x56x96。
1.4 Patch Merging
这里先说一下Patch Merging,因为只要第一个stage先经过Linear Embedding,后3个stage都是Patch Merging。这里的Patch Merging类似于CNN中的池化,为了使窗口获得多尺度的特征,需要在每一个block之前进行下采样,这样在窗口大小不变的情况下图像变小、维度增加,可以获得更多的特征信息。如下图所示:
论文中提到将图像下采样2倍,在Patch Merging中会隔一个点采样。细心的人这时会发现,我们图像的维度从 4x4x1 变为了 2x2x4,但是在模型中是从变为,即大小缩小一半,维度扩大2倍。但是在上图的操作中维度扩大了4倍,论文中为了使其与CNN中池化保持一致,在Patch Merging之后会在再通过一次卷积使维度缩小一半,以达到大小缩小一半,维度扩大2倍的效果。
到这一步Swin Transformer已经拥有了感知多尺度特征的手段。
1.5 Swin Transformer Block
Swin Transformer中最重要的模块,这里是用到了Transformer中的编码器,但是做了一些改动,Transformer中是单纯的多头自注意力(MSA),而这里用到了W-MSA和SW-MSA。
- W-MSA
引入Windows Multi-head Self-Attention(窗口多头自注意力)就是为了解决开篇提到的计算量问题。W-MSA会将图像划分为 MxM (在论文中M默认为7)的不重叠窗口,然后多头自注意力只在同一个窗口内做,这样就从全局的自注意力变为了窗口内的自注意力,大大减小了计算量
如上公式,若图像长宽为 224x224 、M=7,则从 22 4 4 224^4 2244降为 22 4 4 × 49 224^4×49 2244×49
-
SW-MSA
如果只进行W-MSA会发现一个问题,窗口之间是没用通信的,也就是窗口与窗口之间是独立计算的,就不能像卷积一样移动去感受不同像素范围。于是,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块让窗口移动起来。
例如上图所示,将图像分为了4个窗口,在W-MSA中4个窗口分别做MSA;然后在下一层SW-MSA左上角的窗口会向右和向下移动个步伐。如右图所示将图像分割为9个窗口,会发现上一层不属于同一个窗口的patch,经过移动后融合为一个窗口,窗口与窗口之间成功进行了信息交流。
如果9个窗口分别进行SW-MSA,那么从上一层4个相同的窗口变为9个不完全相同的窗口进行计算,不但不利于并行计算,而且窗口数量也增大两倍多,通过W-MSA积攒的计算优势一去不复返。
那么该如何解决呢?有人想到padding补0,这样虽然可以保证窗口大小相同进行并行计算,但是数量是增多的。论文中,作者提出通过掩码的方式进行自注意力的计算
A、B、C经过顺时针循环位移,将图像窗口补为大小相同的4个窗口。但是同一个窗口中,颜色不同的部分不能直接进行自注意力,需要在计算之后分别乘上对应的掩码。
掩码是加上一个很大的负值,在经过softmax之后该部分就会变为0。1的部分全属于同一窗口可直接进行自注意力计算。
1.6 代码
Swin Transformer网络模型如下:
""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`- https://arxiv.org/pdf/2103.14030
Code/weights from https://github.com/microsoft/Swin-Transformer
"""import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optionaldef drop_path_f(x, drop_prob: float = 0., training: bool = False):"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted forchanging the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use'survival rate' as the argument."""if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_probshape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNetsrandom_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_() # binarizeoutput = x.div(keep_prob) * random_tensorreturn outputclass DropPath(nn.Module):"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path_f(x, self.drop_prob, self.training)def window_partition(x, window_size: int):"""将feature map按照window_size划分成一个个没有重叠的windowArgs:x: (B, H, W, C)window_size (int): window size(M)Returns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)# permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]# view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windowsdef window_reverse(windows, window_size: int, H: int, W: int):"""将一个个window还原成一个feature mapArgs:windows: (num_windows*B, window_size, window_size, C)window_size (int): Window size(M)H (int): Height of imageW (int): Width of imageReturns:x: (B, H, W, C)"""B = int(windows.shape[0] / (H * W / window_size / window_size))# view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)# permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]# view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)return xclass PatchEmbed(nn.Module):"""2D Image to Patch Embedding"""def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):super().__init__()patch_size = (patch_size, patch_size)self.patch_size = patch_sizeself.in_chans = in_cself.embed_dim = embed_dimself.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()def forward(self, x):_, _, H, W = x.shape# padding# 如果输入图片的H,W不是patch_size的整数倍,需要进行paddingpad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)if pad_input:# to pad the last 3 dimensions,# (W_left, W_right, H_top,H_bottom, C_front, C_back)x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],0, self.patch_size[0] - H % self.patch_size[0],0, 0))# 下采样patch_size倍x = self.proj(x)_, _, H, W = x.shape# flatten: [B, C, H, W] -> [B, C, HW]# transpose: [B, C, HW] -> [B, HW, C]x = x.flatten(2).transpose(1, 2)x = self.norm(x)return x, H, Wclass PatchMerging(nn.Module):r""" Patch Merging Layer.Args:dim (int): Number of input channels.norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm"""def __init__(self, dim, norm_layer=nn.LayerNorm):super().__init__()self.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x, H, W):"""x: B, H*W, C"""B, L, C = x.shapeassert L == H * W, "input feature has wrong size"x = x.view(B, H, W, C)# padding# 如果输入feature map的H,W不是2的整数倍,需要进行paddingpad_input = (H % 2 == 1) or (W % 2 == 1)if pad_input:# to pad the last 3 dimensions, starting from the last dimension and moving forward.# (C_front, C_back, W_left, W_right, H_top, H_bottom)# 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C]x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]x = self.norm(x)x = self.reduction(x) # [B, H/2*W/2, 2*C]return xclass Mlp(nn.Module):""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.drop1 = nn.Dropout(drop)self.fc2 = nn.Linear(hidden_features, out_features)self.drop2 = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop1(x)x = self.fc2(x)x = self.drop2(x)return xclass WindowAttention(nn.Module):r""" Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.Args:dim (int): Number of input channels.window_size (tuple[int]): The height and width of the window.num_heads (int): Number of attention heads.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Trueattn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0proj_drop (float, optional): Dropout ratio of output. Default: 0.0"""def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_size # [Mh, Mw]self.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # [2*Mh-1 * 2*Mw-1, nH]# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # [2, Mh, Mw]coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw]# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw]relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2]relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw]self.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)def forward(self, x, mask: Optional[torch.Tensor] = None):"""Args:x: input features with shape of (num_windows*B, Mh*Mw, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""# [batch_size*num_windows, Mh*Mw, total_embed_dim]B_, N, C = x.shape# qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]# reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]# permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)# [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)# transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]q = q * self.scaleattn = (q @ k.transpose(-2, -1))# relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # [nH, Mh*Mw, Mh*Mw]attn = attn + relative_position_bias.unsqueeze(0)if mask is not None:# mask: [nW, Mh*Mw, Mh*Mw]nW = mask.shape[0] # num_windows# attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]# mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)# @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]# transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]# reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return xclass SwinTransformerBlock(nn.Module):r""" Swin Transformer Block.Args:dim (int): Number of input channels.num_heads (int): Number of attention heads.window_size (int): Window size.shift_size (int): Shift size for SW-MSA.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Truedrop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float, optional): Stochastic depth rate. Default: 0.0act_layer (nn.Module, optional): Activation layer. Default: nn.GELUnorm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm"""def __init__(self, dim, num_heads, window_size=7, shift_size=0,mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.dim = dimself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeself.mlp_ratio = mlp_ratioassert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,attn_drop=attn_drop, proj_drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)def forward(self, x, attn_mask):H, W = self.H, self.WB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# pad feature maps to multiples of window size# 把feature map给pad到window size的整数倍pad_l = pad_t = 0pad_r = (self.window_size - W % self.window_size) % self.window_sizepad_b = (self.window_size - H % self.window_size) % self.window_sizex = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))_, Hp, Wp, _ = x.shape# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = xattn_mask = None# partition windowsx_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C]x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C]# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=attn_mask) # [nW*B, Mh*Mw, C]# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # [nW*B, Mh, Mw, C]shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # [B, H', W', C]# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xif pad_r > 0 or pad_b > 0:# 把前面pad的数据移除掉x = x[:, :H, :W, :].contiguous()x = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass BasicLayer(nn.Module):"""A basic Swin Transformer layer for one stage.Args:dim (int): Number of input channels.depth (int): Number of blocks.num_heads (int): Number of attention heads.window_size (int): Local window size.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Truedrop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNormdownsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: Noneuse_checkpoint (bool): Whether to use checkpointing to save memory. Default: False."""def __init__(self, dim, depth, num_heads, window_size,mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):super().__init__()self.dim = dimself.depth = depthself.window_size = window_sizeself.use_checkpoint = use_checkpointself.shift_size = window_size // 2# build blocksself.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim,num_heads=num_heads,window_size=window_size,shift_size=0 if (i % 2 == 0) else self.shift_size,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,drop=drop,attn_drop=attn_drop,drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer)for i in range(depth)])# patch merging layerif downsample is not None:self.downsample = downsample(dim=dim, norm_layer=norm_layer)else:self.downsample = Nonedef create_mask(self, x, H, W):# calculate attention mask for SW-MSA# 保证Hp和Wp是window_size的整数倍Hp = int(np.ceil(H / self.window_size)) * self.window_sizeWp = int(np.ceil(W / self.window_size)) * self.window_size# 拥有和feature map一样的通道排列顺序,方便后续window_partitionimg_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # [1, Hp, Wp, 1]h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size) # [nW, Mh, Mw, 1]mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw]attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]# [nW, Mh*Mw, Mh*Mw]attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))return attn_maskdef forward(self, x, H, W):attn_mask = self.create_mask(x, H, W) # [nW, Mh*Mw, Mh*Mw]for blk in self.blocks:blk.H, blk.W = H, Wif not torch.jit.is_scripting() and self.use_checkpoint:x = checkpoint.checkpoint(blk, x, attn_mask)else:x = blk(x, attn_mask)if self.downsample is not None:x = self.downsample(x, H, W)H, W = (H + 1) // 2, (W + 1) // 2return x, H, Wclass SwinTransformer(nn.Module):r""" Swin TransformerA PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -https://arxiv.org/pdf/2103.14030Args:patch_size (int | tuple(int)): Patch size. Default: 4in_chans (int): Number of input image channels. Default: 3num_classes (int): Number of classes for classification head. Default: 1000embed_dim (int): Patch embedding dimension. Default: 96depths (tuple(int)): Depth of each Swin Transformer layer.num_heads (tuple(int)): Number of attention heads in different layers.window_size (int): Window size. Default: 7mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truedrop_rate (float): Dropout rate. Default: 0attn_drop_rate (float): Attention dropout rate. Default: 0drop_path_rate (float): Stochastic depth rate. Default: 0.1norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.patch_norm (bool): If True, add normalization after patch embedding. Default: Trueuse_checkpoint (bool): Whether to use checkpointing to save memory. Default: False"""def __init__(self, patch_size=4, in_chans=3, num_classes=1000,embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),window_size=7, mlp_ratio=4., qkv_bias=True,drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,norm_layer=nn.LayerNorm, patch_norm=True,use_checkpoint=False, **kwargs):super().__init__()self.num_classes = num_classesself.num_layers = len(depths)self.embed_dim = embed_dimself.patch_norm = patch_norm# stage4输出特征矩阵的channelsself.num_features = int(embed_dim * 2 ** (self.num_layers - 1))self.mlp_ratio = mlp_ratio# split image into non-overlapping patchesself.patch_embed = PatchEmbed(patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)self.pos_drop = nn.Dropout(p=drop_rate)# stochastic depthdpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule# build layersself.layers = nn.ModuleList()for i_layer in range(self.num_layers):# 注意这里构建的stage和论文图中有些差异# 这里的stage不包含该stage的patch_merging层,包含的是下个stage的layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),depth=depths[i_layer],num_heads=num_heads[i_layer],window_size=window_size,mlp_ratio=self.mlp_ratio,qkv_bias=qkv_bias,drop=drop_rate,attn_drop=attn_drop_rate,drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],norm_layer=norm_layer,downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,use_checkpoint=use_checkpoint)self.layers.append(layers)self.norm = norm_layer(self.num_features)self.avgpool = nn.AdaptiveAvgPool1d(1)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):nn.init.trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)def forward(self, x):# x: [B, L, C]x, H, W = self.patch_embed(x)x = self.pos_drop(x)for layer in self.layers:x, H, W = layer(x, H, W)x = self.norm(x) # [B, L, C]x = self.avgpool(x.transpose(1, 2)) # [B, C, 1]x = torch.flatten(x, 1)x = self.head(x)return xdef swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):# trained ImageNet-1K# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pthmodel = SwinTransformer(in_chans=3,patch_size=4,window_size=7,embed_dim=96,depths=(2, 2, 6, 2),num_heads=(3, 6, 12, 24),num_classes=num_classes,**kwargs)return model
引入预训练模型之后,进行花类数据集(上篇博客提到)进行微调,代码如下:
import os
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformsfrom my_dataset import MyDataSet
from model import swin_tiny_patch4_window7_224 as create_model
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")if os.path.exists("../weights") is False:os.makedirs("../weights")tb_writer = SummaryWriter()train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)img_size = 224data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)model = create_model(num_classes=args.num_classes).to(device)if args.weights != "":assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)weights_dict = torch.load(args.weights, map_location=device)["model"]# 删除有关分类类别的权重for k in list(weights_dict.keys()):if "head" in k:del weights_dict[k]print(model.load_state_dict(weights_dict, strict=False))if args.freeze_layers:for name, para in model.named_parameters():# 除head外,其他权重全部冻结if "head" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "../weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=10)parser.add_argument('--batch-size', type=int, default=8)parser.add_argument('--lr', type=float, default=0.0001)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default="../data/flower_photos")# 预训练权重路径,如果不想载入就设置为空字符parser.add_argument('--weights', type=str, default='../weights/swin_tiny_patch4_window7_224.pth',help='initial weights path')# 是否冻结权重parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)
模型预测代码:
import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import swin_tiny_patch4_window7_224 as create_modeldef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")img_size = 224data_transform = transforms.Compose([transforms.Resize(int(img_size * 1.14)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load imageimg_path = "../data/Image/flower.png"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)plt.show()img2 = img# [N, C, H, W]img = img.convert('RGB')img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = 'class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = create_model(num_classes=5).to(device)# load model weightsmodel_weight_path = "../weights/model-9.pth"model.load_state_dict(torch.load(model_weight_path, map_location=device))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10} prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))plt.imshow(img2)plt.show()if __name__ == '__main__':main()
代码预测结果如下图所示:
大约99.8%的可能性为玫瑰类别
总结
Swin Transformer模型通过创新的窗口机制在保持Transformer模型优势的同时,显著降低了计算成本,使其在计算机视觉任务中表现出色。模型的层次化特征提取和多尺度感知能力使其在图像分类、目标检测和语义分割等任务中取得了优异的性能。通过实际代码的展示和微调应用,预测结果的高准确率进一步证明了模型的强大能力,特别是在处理高分辨率图像时的效率和准确性。