swin transformer中相对位置编码解析

devtools/2025/1/21 1:34:01/

在论文中,作者发现相对位置编码的效果会更好一些。

代码的实现为:

 # get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.register_buffer("relative_position_index", relative_position_index)

在forward中的计算为:

relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)

完整的class实现为:

class WindowAttention(nn.Module):r""" Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.Args:dim (int): Number of input channels.window_size (tuple[int]): The height and width of the window.num_heads (int): Number of attention heads.qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if setattn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0proj_drop (float, optional): Dropout ratio of output. Default: 0.0"""def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_size  # Wh, Wwself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)def forward(self, x, mask=None):"""Args:x: input features with shape of (num_windows*B, N, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return xdef extra_repr(self) -> str:return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'def flops(self, N):# calculate flops for 1 window with token length of Nflops = 0# qkv = self.qkv(x)flops += N * self.dim * 3 * self.dim# attn = (q @ k.transpose(-2, -1))flops += self.num_heads * N * (self.dim // self.num_heads) * N#  x = (attn @ v)flops += self.num_heads * N * N * (self.dim // self.num_heads)# x = self.proj(x)flops += N * self.dim * self.dimreturn flops

这个计算过程,我们用windows size = 2为例来看一下计算过程。图片取自链接transformer入门 论文阅读(4) Swin Transformer | shifted window,relative position bias详解

通过这个图示,是比较容易理解相对位置编码的计算过程的,下面我们在实际的代码上跑一下,看看实际的数值变化,以及在forward中的计算过程。

我们使用swin transformer的imagenet image classification任务为例,逐步来解释每行代码。在这个任务中,window_size = 7。

1. 定义relative_position_bias_table

        # define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

这个定义了relative_position_bias_table,对应上图示中的relative_position_bias_table。

这个值的初始化值为shape为[169, 4]的全0值。169是(2 * window_size[0] - 1) * (2 * window_size[1] - 1),也就是(27-1)(2*7-1) = 169.

然后将其初始化为一个服从截断正态分布的随机值,标准差为 0.02。截断正态分布的范围通常限制在均值两侧一定范围内,避免生成过大或过小的值。

trunc_normal_(self.relative_position_bias_table, std=.02)

经过截断和正态分布的初始化后,self.relative_position_bias_table的shape是[169, 4],说明有169个相对位置可以索引,有4个头。

self.relative_position_bias_table的值是:

self.relative_position_bias_table = Parameter containing:
tensor([[ 1.3016e-02,  1.2930e-02,  1.5971e-02, -2.9950e-02],[-1.2150e-02, -1.8186e-02,  1.8201e-02, -2.9683e-02],[-1.4085e-03, -9.6917e-03,  1.7187e-02, -2.1197e-02],[ 4.5870e-04, -2.5759e-02,  1.0428e-02,  7.8378e-03],...[-4.6959e-03, -1.1017e-02,  1.3361e-02,  7.7851e-03],[ 1.7211e-02, -9.5882e-03,  6.2699e-02,  7.8999e-03],[ 1.5927e-02, -5.5237e-02,  1.6605e-02, -1.4664e-02],[-2.6448e-02,  8.7442e-03,  5.1785e-03,  3.0192e-02]],requires_grad=True)

这个shape的原因是:

相对位置的意义为:对于窗口注意力机制,每个 token 的位置都是相对于窗口内其他 token 定义的。窗口大小为 (Wh, Ww),窗口内总共有 Wh * Ww 个 token。相对位置表示的是两个 token 在垂直方向(高度)和水平方向(宽度)上的偏移量。

例如:

  • 一个窗口大小为 (3, 3)
    • 水平相对位置范围是:[-2, -1, 0, 1, 2](总共 2 * Wh - 1 = 5)。
    • 垂直相对位置范围是:[-2, -1, 0, 1, 2](总共 2 * Ww - 1 = 5)。

这意味着在 2D 平面中,窗口内的 token 之间的相对位置总共有:

(2 * Wh - 1) * (2 * Ww - 1)

例如,(3, 3) 窗口有 5 * 5 = 25 种可能的相对位置。


self.relative_position_bias_table` 的形状是(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)

  • 第一维度(2 * Wh - 1) * (2 * Ww - 1),表示所有可能的相对位置。
  • 第二维度num_heads,因为每个注意力头都会有一个单独的偏置。

例如:

  • 窗口大小 (3, 3),有 25 种可能的相对位置,num_heads=8
  • relative_position_bias_table 的形状为 (25, 8)

这一表将存储每个相对位置对于每个注意力头的偏置值。

2.获取绝对位置

# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww

在这里,

coords_h = tensor([0, 1, 2, 3, 4, 5, 6]), coords_w = tensor([0, 1, 2, 3, 4, 5, 6])

经过meshgrid和stack计算之后,

coords = tensor([[[0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1],[2, 2, 2, 2, 2, 2, 2],[3, 3, 3, 3, 3, 3, 3],[4, 4, 4, 4, 4, 4, 4],[5, 5, 5, 5, 5, 5, 5],[6, 6, 6, 6, 6, 6, 6]],[[0, 1, 2, 3, 4, 5, 6],[0, 1, 2, 3, 4, 5, 6],[0, 1, 2, 3, 4, 5, 6],[0, 1, 2, 3, 4, 5, 6],[0, 1, 2, 3, 4, 5, 6],[0, 1, 2, 3, 4, 5, 6],[0, 1, 2, 3, 4, 5, 6]]])

用图示表示就是

然后经过flatten计算

coords_flatten = tensor(
[[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,6],
[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2,3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]])

coords_flatten.shape = torch.Size([2, 49])

展平后代表的是相同的意义,只是排列形式发生了变化。

3.获取相对位置索引(通过绝对位置相减)

relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

coords_flatten[:, :, None]是shape为[2, 49, 1]的tensor。None在最后添加了一维。

coords_flatten[:, :, None] = tensor([[[0],[0],[0],[0],[0],[0],[0],[1],[1],[1],[1],[1],[1],[1],[2],[2],[2],[2],[2],[2],[2],[3],[3],[3],[3],[3],[3],[3],[4],[4],[4],[4],[4],[4],[4],[5],[5],[5],[5],[5],[5],[5],[6],[6],[6],[6],[6],[6],[6]],[[0],[1],[2],[3],[4],[5],[6],[0],[1],[2],[3],[4],[5],[6],[0],[1],[2],[3],[4],[5],[6],[0],[1],[2],[3],[4],[5],[6],[0],[1],[2],[3],[4],[5],[6],[0],[1],[2],[3],[4],[5],[6],[0],[1],[2],[3],[4],[5],[6]]])

coords_flatten[:, None, :]是shape为[2, 1, 49]的tensor,None在中间添加了一维。

coords_flatten[:, None, :] = tensor(
[[[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3,3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6,6, 6, 6]],
[[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1,2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3,4, 5, 6]]])

coords_flatten[:, None, :]与coords_flatten在数值上是一致的,只是多添加了一维。

coords_flatten = tensor(
[[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,6],
[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2,3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]])

coords_flatten[:, :, None]的shape是[2, 49, 1],coords_flatten[:, None, :]的shape是[2, 1, 49],通过广播机制进行相减,得到相对坐标relative_coords,relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]

relative_coords = 
tensor([[[ 0,  0,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2,-2, -2, -2, -2, -3, -3, -3, -3, -3, -3, -3, -4, -4, -4, -4, -4, -4,-4, -5, -5, -5, -5, -5, -5, -5, -6, -6, -6, -6, -6, -6, -6],[ 0,  0,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2,-2, -2, -2, -2, -3, -3, -3, -3, -3, -3, -3, -4, -4, -4, -4, -4, -4,-4, -5, -5, -5, -5, -5, -5, -5, -6, -6, -6, -6, -6, -6, -6],[ 0,  0,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2,-2, -2, -2, -2, -3, -3, -3, -3, -3, -3, -3, -4, -4, -4, -4, -4, -4,-4, -5, -5, -5, -5, -5, -5, -5, -6, -6, -6, -6, -6, -6, -6],[ 0,  0,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2,-2, -2, -2, -2, -3, -3, -3, -3, -3, -3, -3, -4, -4, -4, -4, -4, -4,-4, -5, -5, -5, -5, -5, -5, -5, -6, -6, -6, -6, -6, -6, -6],...4,  4,  4,  4,  3,  3,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  2,2,  1,  1,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0],[ 6,  6,  6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  5,  5,  4,  4,  4,4,  4,  4,  4,  3,  3,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  2,2,  1,  1,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0]],[[ 0, -1, -2, -3, -4, -5, -6,  0, -1, -2, -3, -4, -5, -6,  0, -1, -2,-3, -4, -5, -6,  0, -1, -2, -3, -4, -5, -6,  0, -1, -2, -3, -4, -5,-6,  0, -1, -2, -3, -4, -5, -6,  0, -1, -2, -3, -4, -5, -6],[ 1,  0, -1, -2, -3, -4, -5,  1,  0, -1, -2, -3, -4, -5,  1,  0, -1,-2, -3, -4, -5,  1,  0, -1, -2, -3, -4, -5,  1,  0, -1, -2, -3, -4,...[ 5,  4,  3,  2,  1,  0, -1,  5,  4,  3,  2,  1,  0, -1,  5,  4,  3,2,  1,  0, -1,  5,  4,  3,  2,  1,  0, -1,  5,  4,  3,  2,  1,  0,-1,  5,  4,  3,  2,  1,  0, -1,  5,  4,  3,  2,  1,  0, -1],[ 6,  5,  4,  3,  2,  1,  0,  6,  5,  4,  3,  2,  1,  0,  6,  5,  4,3,  2,  1,  0,  6,  5,  4,  3,  2,  1,  0,  6,  5,  4,  3,  2,  1,0,  6,  5,  4,  3,  2,  1,  0,  6,  5,  4,  3,  2,  1,  0]]])

relative_coords的shape是torch.Size([2, 49, 49])。

针对relative_coords的第[2,0,49]的数值,用图示所示就是下图所示,以蓝色点为参考点,其他点相对参考点的距离。

然后以黄色点为参考点,其他点相对参考点的距离。

剩下的依次类推。

整体就是知乎文章图中这部分

的扩展。

4. 将shape为[2,49,49]relative_coords的49个不同参考点的相对距离值拉成一维

relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

此操作对应示例中的

经过处理后,relative_coords的shape是[49, 49, 2]。他的值可以用下面的图示表示:

即将以每一个参考点来说,其他位置的点的相对距离,都放在一行上。relative_coords的值是:

tensor([[[ 0,  0],[ 0, -1],[ 0, -2],[ 0, -3],[ 0, -4],[ 0, -5],[ 0, -6],[-1,  0],[-1, -1],[-1, -2],[-1, -3],[-1, -4],[-1, -5],[-1, -6],[-2,  0],[-2, -1],[-2, -2],[-2, -3],[-2, -4],[-2, -5],[-2, -6],[-3,  0],[-3, -1],[-3, -2],[-3, -3],[-3, -4],[-3, -5],[-3, -6],[-4,  0],[-4, -1],[-4, -2],[-4, -3],[-4, -4],[-4, -5],[-4, -6],[-5,  0],[-5, -1],[-5, -2],[-5, -3],[-5, -4],[-5, -5],[-5, -6],[-6,  0],[-6, -1],[-6, -2],[-6, -3],[-6, -4],[-6, -5],[-6, -6]],[[ 0,  1],[ 0,  0],...[ 0,  1],[ 0,  0],[ 0, -1]],[[ 6,  6],[ 6,  5],[ 6,  4],[ 6,  3],[ 6,  2],[ 6,  1],[ 6,  0],[ 5,  6],[ 5,  5],[ 5,  4],[ 5,  3],[ 5,  2],[ 5,  1],[ 5,  0],[ 4,  6],[ 4,  5],[ 4,  4],[ 4,  3],[ 4,  2],[ 4,  1],[ 4,  0],[ 3,  6],[ 3,  5],[ 3,  4],[ 3,  3],[ 3,  2],[ 3,  1],[ 3,  0],[ 2,  6],[ 2,  5],[ 2,  4],[ 2,  3],[ 2,  2],[ 2,  1],[ 2,  0],[ 1,  6],[ 1,  5],[ 1,  4],[ 1,  3],[ 1,  2],[ 1,  1],[ 1,  0],[ 0,  6],[ 0,  5],[ 0,  4],[ 0,  3],[ 0,  2],[ 0,  1],[ 0,  0]]])

5.对横纵坐标值进行数值处理

relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

relative_coords[:, :, 0]代表的是所有的横坐标,relative_coords[:, :, 1]代表的是所有的纵坐标。

relative_coords[:, :, 0]和relative_coords[:, :, 1]的shape都是[49, 49]。

relative_coords[:, :, 0] = 
tensor([[ 0,  0,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2,-2, -2, -2, -3, -3, -3, -3, -3, -3, -3, -4, -4, -4, -4, -4, -4, -4, -5,-5, -5, -5, -5, -5, -5, -6, -6, -6, -6, -6, -6, -6],[ 0,  0,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2,-2, -2, -2, -3, -3, -3, -3, -3, -3, -3, -4, -4, -4, -4, -4, -4, -4, -5,-5, -5, -5, -5, -5, -5, -6, -6, -6, -6, -6, -6, -6],...[ 6,  6,  6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  5,  5,  4,  4,  4,  4,4,  4,  4,  3,  3,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  2,  2,  1,1,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0],[ 6,  6,  6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  5,  5,  4,  4,  4,  4,4,  4,  4,  3,  3,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  2,  2,  1,1,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0],[ 6,  6,  6,  6,  6,  6,  6,  5,  5,  5,  5,  5,  5,  5,  4,  4,  4,  4,4,  4,  4,  3,  3,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  2,  2,  1,1,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0,  0,  0]])
relative_coords[:, :, 1] = 
tensor([[ 0, -1, -2, -3, -4, -5, -6,  0, -1, -2, -3, -4, -5, -6,  0, -1, -2, -3,-4, -5, -6,  0, -1, -2, -3, -4, -5, -6,  0, -1, -2, -3, -4, -5, -6,  0,-1, -2, -3, -4, -5, -6,  0, -1, -2, -3, -4, -5, -6],[ 1,  0, -1, -2, -3, -4, -5,  1,  0, -1, -2, -3, -4, -5,  1,  0, -1, -2,-3, -4, -5,  1,  0, -1, -2, -3, -4, -5,  1,  0, -1, -2, -3, -4, -5,  1,0, -1, -2, -3, -4, -5,  1,  0, -1, -2, -3, -4, -5],[ 2,  1,  0, -1, -2, -3, -4,  2,  1,  0, -1, -2, -3, -4,  2,  1,  0, -1,-2, -3, -4,  2,  1,  0, -1, -2, -3, -4,  2,  1,  0, -1, -2, -3, -4,  2,1,  0, -1, -2, -3, -4,  2,  1,  0, -1, -2, -3, -4],...[ 5,  4,  3,  2,  1,  0, -1,  5,  4,  3,  2,  1,  0, -1,  5,  4,  3,  2,1,  0, -1,  5,  4,  3,  2,  1,  0, -1,  5,  4,  3,  2,  1,  0, -1,  5,4,  3,  2,  1,  0, -1,  5,  4,  3,  2,  1,  0, -1],[ 6,  5,  4,  3,  2,  1,  0,  6,  5,  4,  3,  2,  1,  0,  6,  5,  4,  3,2,  1,  0,  6,  5,  4,  3,  2,  1,  0,  6,  5,  4,  3,  2,  1,  0,  6,5,  4,  3,  2,  1,  0,  6,  5,  4,  3,  2,  1,  0]])

第一步是横纵坐标都加6,避免距离值是负值。

relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1

经过计算后,relative_coords变成:

然后对横坐标乘以2 * self.window_size[1] - 1。

relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1

经过计算后,relative_coords的值为

对横坐标乘以2 * self.window_size[1] - 1后,这里面的最大值为左下角的(156,12),156+12=168. 我们在最开始定义的relative_position_bias_table的shape也是(169, num_heads) 的,正好可以把relative_position_bias_table全部索引到。

        # define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

6. 横纵坐标加和

relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
relative_position_index = 
tensor([[ 84,  83,  82,  81,  80,  79,  78,  71,  70,  69,  68,  67,  66,  65,58,  57,  56,  55,  54,  53,  52,  45,  44,  43,  42,  41,  40,  39,32,  31,  30,  29,  28,  27,  26,  19,  18,  17,  16,  15,  14,  13,6,   5,   4,   3,   2,   1,   0],[ 85,  84,  83,  82,  81,  80,  79,  72,  71,  70,  69,  68,  67,  66,59,  58,  57,  56,  55,  54,  53,  46,  45,  44,  43,  42,  41,  40,33,  32,  31,  30,  29,  28,  27,  20,  19,  18,  17,  16,  15,  14,7,   6,   5,   4,   3,   2,   1],...[167, 166, 165, 164, 163, 162, 161, 154, 153, 152, 151, 150, 149, 148,141, 140, 139, 138, 137, 136, 135, 128, 127, 126, 125, 124, 123, 122,115, 114, 113, 112, 111, 110, 109, 102, 101, 100,  99,  98,  97,  96,89,  88,  87,  86,  85,  84,  83],[168, 167, 166, 165, 164, 163, 162, 155, 154, 153, 152, 151, 150, 149,142, 141, 140, 139, 138, 137, 136, 129, 128, 127, 126, 125, 124, 123,116, 115, 114, 113, 112, 111, 110, 103, 102, 101, 100,  99,  98,  97,90,  89,  88,  87,  86,  85,  84]])

relative_position_index的shape是[49, 49]。这里面最大的值是168,最小的值是0,正好可以对最开始的self.relative_position_bias_table的值全部索引到,self.relative_position_bias_table的shape是[169,4],169个可索引的位置,4个头。

7. forward中的计算

relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)

将其拆分来看。我们把上面的代码拆分为:

relative_position_index_tmp = self.relative_position_index.view(-1)relative_position_bias_table_tmp = self.relative_position_bias_table[relative_position_index_tmp]relative_position_bias = relative_position_bias_table_tmp.view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)
  • self.relative_position_index.view(-1)操作

上面已经说到relative_position_index的shape是[49, 49],view(-1)操作会将所有的值放在一个维度,所以假设relative_position_index_tmp = self.relative_position_index.view(-1),relative_position_index_tmp的shape是[2401],relative_position_index_tmp的值是

tensor([ 84,  83,  82,  81,  80,  79,  78,  71,  70,  69,  68,  67,  66,  65,58,  57,  56,  55,  54,  53,  52,  45,  44,  43,  42,  41,  40,  39,32,  31,  30,  29,  28,  27,  26,  19,  18,  17,  16,  15,  14,  13,6,   5,   4,   3,   2,   1,   0,  85,  84,  83,  82,  81,  80,  79,72,  71,  70,  69,  68,  67,  66,  59,  58,  57,  56,  55,  54,  53,46,  45,  44,  43,  42,  41,  40,  33,  32,  31,  30,  29,  28,  27,20,  19,  18,  17,  16,  15,  14,   7,   6,   5,   4,   3,   2,   1,86,  85,  84,  83,  82,  81,  80,  73,  72,  71,  70,  69,  68,  67,60,  59,  58,  57,  56,  55,  54,  47,  46,  45,  44,  43,  42,  41,...128, 127, 126, 125, 124, 123, 122, 115, 114, 113, 112, 111, 110, 109,102, 101, 100,  99,  98,  97,  96,  89,  88,  87,  86,  85,  84,  83,168, 167, 166, 165, 164, 163, 162, 155, 154, 153, 152, 151, 150, 149,142, 141, 140, 139, 138, 137, 136, 129, 128, 127, 126, 125, 124, 123,116, 115, 114, 113, 112, 111, 110, 103, 102, 101, 100,  99,  98,  97,90,  89,  88,  87,  86,  85,  84], device='cuda:0')
  • self.relative_position_bias_table[relative_position_index_tmp]操作

上面说到,self.relative_position_bias_table的shape是[169,4],在上面也有打印它的值。relative_position_index_tmp的shape是[2401]的索引值,所以此步操作就相当于对self.relative_position_bias_table做2401次索引,每次索引出来的都是shape为[1,4]的tensor,所以
relative_position_bias_table_tmp = self.relative_position_bias_table[relative_position_index_tmp]后,relative_position_bias_table_tmp的shape为[2401, 4],他的部分值展示如下:

tensor([[-2.7874e-02, -3.6342e-02,  2.8827e-02, -2.3673e-02],[ 4.1464e-02,  7.6075e-03,  5.3314e-03, -1.4585e-03],[-1.4711e-02, -1.3424e-02, -1.3756e-03, -1.3826e-03],[-3.6255e-02, -1.9680e-03,  1.6092e-02,  2.3690e-02],...[-2.0880e-02, -1.2093e-02,  4.1462e-02,  1.8901e-02],[-1.6958e-02, -6.0754e-03, -1.3342e-02, -7.5932e-04],[ 8.9761e-03, -1.1548e-02, -2.5437e-02,  1.5095e-02],[-2.7874e-02, -3.6342e-02,  2.8827e-02, -2.3673e-02]]
  • relative_position_bias_table_tmp.view(self.window_size[0] * self.window_size[1],self.window_size[0] * self.window_size[1], -1)操作

这一步就是将之前的view(-1)操作再转换回原始尺寸。所以relative_position_bias = relative_position_bias_table_tmp.view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1),relative_position_bias的shape是[49, 49, 4].

  • relative_position_bias.permute(2, 0, 1).contiguous()操作

上一步relative_position_bias的shape是[49, 49, 4],经过permute操作后,shape变为[4, 49, 49].

attn = attn + relative_position_bias.unsqueeze(0),attn经过attention计算后,shape是([6400, 4, 49, 49]),这里的relative_position_bias经过unsqueeze一下,正好可以和attn的结果相加。


http://www.ppmy.cn/devtools/152234.html

相关文章

51c自动驾驶~合集48

我自己的原文哦~ https://blog.51cto.com/whaosoft/13133866 #UDMC 考虑轨迹预测的统一决策控制框架 论文:https://arxiv.org/pdf/2501.02530 代码:​​https://github.com/henryhcliu/udmc_carla.git​​ 1. 摘要 当前的自动驾驶系统常常在确…

Java 8 Stream API

文章目录 Java 8 Stream API1. Stream2. Stream 的创建3. 常见的 Stream 操作3.1 中间操作3.2 终止操作 4. Stream 的并行操作 Java 8 Stream API Java 8 引入了 Stream API,使得对集合类(如 List、Set 等)的操作变得更加简洁和直观。Stream…

gitignore忽略已经提交过的

已经在.gitignore文件中添加了过滤规则来忽略bin和obj等文件夹,但这些文件夹仍然出现在提交中,可能是因为这些文件夹在添加.gitignore规则之前已经被提交到Git仓库中了。要解决这个问题,您需要从Git的索引中移除这些文件夹,并确保…

1161 Merging Linked Lists (25)

Given two singly linked lists L1​a1​→a2​→⋯→an−1​→an​ and L2​b1​→b2​→⋯→bm−1​→bm​. If n≥2m, you are supposed to reverse and merge the shorter one into the longer one to obtain a list like a1​→a2​→bm​→a3​→a4​→bm−1​⋯. For ex…

如何用selenium来链接并打开比特浏览器进行自动化操作(1)

前言 本文是该专栏的第76篇,后面会持续分享python爬虫干货知识,记得关注。 本文,笔者将基于“比特浏览器”,通过selenium来实现链接并打开比特浏览器,进行相关的“自动化”操作。 值得一提的是,在本专栏之前,笔者有详细介绍过“使用selenium或者pyppeteer(puppeteer)…

InVideo AI技术浅析(二):自然语言处理

InVideo AI的自然语言处理(NLP)模块是整个系统中的关键部分,负责处理和分析用户输入的文本数据,以实现智能化的视频生成和编辑功能。 1. 文本解析与理解 1.1 文本解析过程 文本解析是将用户输入的自然语言文本转换为机器可理解的格式的过程。解析过程可以分为以下几个步…

深度学习基础知识

深度学习是人工智能(AI)和机器学习(ML)领域的一个重要分支,以下是对深度学习基础知识的归纳: 一、定义与原理 定义:深度学习是一种使计算机能够从经验中学习并以概念层次结构的方式理解世界的机…

计算机网络 (44)电子邮件

一、概述 电子邮件(Electronic Mail,简称E-mail)是因特网上最早流行的应用之一,并且至今仍然是因特网上最重要、最实用的应用之一。它利用计算机技术和互联网,实现了信息的快速、便捷传递。与传统的邮政系统相比&#…