SwinTransformer的相对位置索引的原理以及源码分析

server/2024/10/21 5:43:37/

文章目录

  • 1. 理论分析
  • 2. 完整代码


引用:参考博客链接


1. 理论分析

根据论文中提供的公式可知是在 Q Q Q K K K进行匹配并除以 d \sqrt d d 后加上了相对位置偏执 B B B

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d + B ) V \begin{aligned} &Attention(Q,K,V) = Softmax(\frac{QK^T}{\sqrt d}+B)V \end{aligned} Attention(Q,K,V)=Softmax(d QKT+B)V

    如下图,假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是(0,0),接下来再看看相对位置索引。首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是(0,1),则它相对蓝色像素的相对位置索引为 ( 0 , 0 ) − ( 0 , 1 ) = ( 0 , − 1 ) (0,0)−(0,1)=(0,−1) (0,0)(0,1)=(0,1) 。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4矩阵 。

在这里插入图片描述

对应源代码为:

import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_window_size = [2,2]
num_heads = 3# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 绝对位置索引coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

    请注意,这里描述的一直是相对位置索引,并不是相对位置偏执参数。因为后面我们会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为(0,−1)。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为(0,−1)。可以发现这两者的相对位置索引都是(0,−1),所以他们使用的相对位置偏执参数都是一样的。

    其实讲到这基本已经讲完了,但在源码中作者为了方便把二维索引给转成了一维索引。具体这么转的呢,有人肯定想到,简单啊直接把行、列索引相加不就变一维了吗?比如上面的相对位置索引中有(0,−1)和(−1,0)在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那不就出问题了吗?接下来我们看看源码中是怎么做的。首先在原始的相对位置索引上加上 ( M − 1 ) (M-1) (M1) ( M M M为窗口的大小,在本示例中 M M M=2),加上之后索引中就不会有负数了。
    
在这里插入图片描述

对应源代码为:

relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1

    
接着将所有的行标都乘上2M-1。

在这里插入图片描述

    
对应源代码为:

relative_coords[:, :, 0] *= 2 * window_size[1] - 1

最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述0 + ( − 1 ) = ( − 1 ) + 0 0+(-1)=(-1)+00+(−1)=(−1)+0的问题了,是不是很神奇。

在这里插入图片描述

relative_position_index = relative_coords.sum(-1)

    刚刚上面也说了,之前计算的是相对位置索引,并不是相对位置偏执参数。真正使用到的可训练参数 B ^ \hat{B} B^ 是保存在relative position bias table表里的,这个表的长度是等于 ( 2 M − 1 ) × ( 2 M − 1 ) (2M−1)×(2M−1) (2M1)×(2M1)的。那么上述公式中的相对位置偏执参数 B B B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。

在这里插入图片描述

对应源代码为:

'''
relative_position_bias_table:其shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
'''
relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)'''
index:虽然shape=(window_size[0] * window_size[1], window_size[0] * window_size[1]),但是只有(2 * window_size[0] - 1) * (2 * window_size[1] - 1)个不同的元素。作为索引,正好能一一对应relative_position_bias_table中的元素
'''
index = relative_position_index.view(-1)
relative_position_bias = relative_position_bias_table[index] # index的每一个不同的元素对应relative_position_bias_table中一个值relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

2. 完整代码

import torch
import torch.nn as  nn
from timm.models.layers import trunc_normal_window_size = [2,2]
num_heads = 3# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # # 绝对位置索引 # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords_temp = relative_coords.numpy()
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1relative_coords[:, :, 0] *= 2 * window_size[1] - 1relative_position_index = relative_coords.sum(-1)relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
trunc_normal_(relative_position_bias_table, std=.02)print('relative_position_bias_table:',relative_position_bias_table.shape)
print(relative_position_index.shape)index = relative_position_index.view(-1)
print('index:',index.shape)
relative_position_bias = relative_position_bias_table[index]
print('relative_position_bias_Noreshape:',relative_position_bias.shape)relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwprint('relative_position_bias:',relative_position_bias.shape)

http://www.ppmy.cn/server/55881.html

相关文章

TQ15EG开发板教程:MPSOC创建fmcomms8工程

链接:https://pan.baidu.com/s/1jbuYs9alP2SaqnV5fpNgyg 提取码:r00c 本例程需要实现在hdl加no-OS系统中,通过修改fmcomms8/zcu102项目,实现在MPSOC两个fmc口上运行fmcomms8项目。 目录 1 下载文件与切换版本 2 编译fmcomms8项…

软考中级数据库系统工程师备考经验分享

前几天软考成绩出了,赶紧查询了一下发现自己顺利通过啦(上午63,下午67,开心),因此本文记录一下我的备考经验分享给大家。因为工作中项目管理类的知识没有系统学习过,本来想直接报名软考高级证书…

python读取写入txt文本文件

读取 txt 文件 def read_txt_file(file_path):"""读取文本文件的内容:param file_path: 文本文件的路径:return: 文件内容"""try:with open(file_path, r, encodingutf-8) as file:content file.read()return contentexcept FileNotFoundError…

机器学习原理之 -- XGboost原理详解

XGBoost(eXtreme Gradient Boosting)是近年来在数据科学和机器学习领域中广受欢迎的集成学习算法。它在多个数据科学竞赛中表现出色,被广泛应用于各种机器学习任务。本文将详细介绍XGBoost的由来、基本原理、算法细节、优缺点及应用场景。 X…

Shell Expect自动化交互(示例)

Shell Expect自动化交互 日常linux运维时,经常需要远程登录到服务器,登录过程中需要交互的过程,可能需要输入yes/no等信息,所以就用到expect来实现交互。 关键语法 ❶[#!/usr/bin/expect] 这一行告诉操…

苍穹外卖--sky-take-out(四)10-12

苍穹外卖--sky-take-out(一) 苍穹外卖--sky-take-out(一)-CSDN博客​编辑https://blog.csdn.net/kussm_/article/details/138614737?spm1001.2014.3001.5501https://blog.csdn.net/kussm_/article/details/138614737?spm1001.2…

【Gin】项目搭建 一

环境准备 首先确保自己电脑安装了Golang 开始项目 1、初始化项目 mkdir gin-hello; # 创建文件夹 cd gin-hello; # 需要到刚创建的文件夹里操作 go mod init goserver; # 初始化项目,项目名称:goserver go get -u github.com/gin-gonic/gin; # 下载…

A Threat Actors 出售 18 万名 Shopify 用户信息

BreachForums 论坛成员最近发布了涉及 Shopify 的重大数据泄露事件。 据报道,属于近 180,000 名用户的敏感数据遭到泄露。 Shopify Inc. 是一家总部位于安大略省渥太华的加拿大公司。 开发和营销同名电子商务平台、Shopify POS 销售点系统以及专用于企业的营销工…