文章目录
- 1. excel 示意
- 2. pytorch代码
- 3. window mhsa
1. excel 示意
将一个三维矩阵按照window的大小进行拆分成多块2x2窗口矩阵,具体如下图所示
pytorch_4">2. pytorch代码
- pytorch源码
import torch
import torch.nn as nn
import torch.nn.functional as Ftorch.set_printoptions(precision=3, sci_mode=False)if __name__ == "__main__":run_code = 0batch_size = 2seq_len = 4model_dim = 6patch_total = batch_size * seq_len * model_dimpatch = torch.arange(patch_total).reshape((batch_size, seq_len, model_dim)).to(torch.float32)print(f"patch.shape=\n{patch.shape}")print(f"patch=\n{patch}")patch_unfold = F.unfold(input=patch, kernel_size=(2, 2), stride=(2, 2))print(f"patch_unfold.shape=\n{patch_unfold.shape}")print(f"patch_unfold=\n{patch_unfold}")# patch_unfold = patch_unfold.transpose(-1, -2)print(f"patch_unfold=\n{patch_unfold}")patch_nums = patch_unfold.reshape(batch_size, 4, 6)print(f"patch_nums=\n{patch_nums}")patch_nums_new = patch_nums.transpose(-1, -2)print(f"patch_nums_new.shape=\n{patch_nums_new.shape}")print(f"patch_nums_new=\n{patch_nums_new}")patch_nums_final = patch_nums_new.reshape(12, 2, 2)print(f"patch_nums_final.shape=\n{patch_nums_final.shape}")print(f"patch_nums_final=\n{patch_nums_final}")
- 结果:
patch.shape=
torch.Size([2, 4, 6])
patch=
tensor([[[ 0., 1., 2., 3., 4., 5.],[ 6., 7., 8., 9., 10., 11.],[12., 13., 14., 15., 16., 17.],[18., 19., 20., 21., 22., 23.]],[[24., 25., 26., 27., 28., 29.],[30., 31., 32., 33., 34., 35.],[36., 37., 38., 39., 40., 41.],[42., 43., 44., 45., 46., 47.]]])
patch_unfold.shape=
torch.Size([8, 6])
patch_unfold=
tensor([[ 0., 2., 4., 12., 14., 16.],[ 1., 3., 5., 13., 15., 17.],[ 6., 8., 10., 18., 20., 22.],[ 7., 9., 11., 19., 21., 23.],[24., 26., 28., 36., 38., 40.],[25., 27., 29., 37., 39., 41.],[30., 32., 34., 42., 44., 46.],[31., 33., 35., 43., 45., 47.]])
patch_unfold=
tensor([[ 0., 2., 4., 12., 14., 16.],[ 1., 3., 5., 13., 15., 17.],[ 6., 8., 10., 18., 20., 22.],[ 7., 9., 11., 19., 21., 23.],[24., 26., 28., 36., 38., 40.],[25., 27., 29., 37., 39., 41.],[30., 32., 34., 42., 44., 46.],[31., 33., 35., 43., 45., 47.]])
patch_nums=
tensor([[[ 0., 2., 4., 12., 14., 16.],[ 1., 3., 5., 13., 15., 17.],[ 6., 8., 10., 18., 20., 22.],[ 7., 9., 11., 19., 21., 23.]],[[24., 26., 28., 36., 38., 40.],[25., 27., 29., 37., 39., 41.],[30., 32., 34., 42., 44., 46.],[31., 33., 35., 43., 45., 47.]]])
patch_nums_new.shape=
torch.Size([2, 6, 4])
patch_nums_new=
tensor([[[ 0., 1., 6., 7.],[ 2., 3., 8., 9.],[ 4., 5., 10., 11.],[12., 13., 18., 19.],[14., 15., 20., 21.],[16., 17., 22., 23.]],[[24., 25., 30., 31.],[26., 27., 32., 33.],[28., 29., 34., 35.],[36., 37., 42., 43.],[38., 39., 44., 45.],[40., 41., 46., 47.]]])
patch_nums_final.shape=
torch.Size([12, 2, 2])
patch_nums_final=
tensor([[[ 0., 1.],[ 6., 7.]],[[ 2., 3.],[ 8., 9.]],[[ 4., 5.],[10., 11.]],[[12., 13.],[18., 19.]],[[14., 15.],[20., 21.]],[[16., 17.],[22., 23.]],[[24., 25.],[30., 31.]],[[26., 27.],[32., 33.]],[[28., 29.],[34., 35.]],[[36., 37.],[42., 43.]],[[38., 39.],[44., 45.]],[[40., 41.],[46., 47.]]])
3. window mhsa
- excel 示意图
- pytorch
import torch
import torch.nn as nn
import torch.nn.functional as Ftorch.set_printoptions(precision=3, sci_mode=False)if __name__ == "__main__":run_code = 0bs = 2num_patch = 16patch_depth = 4window_size = 2image_height = image_width = 4num_patch_in_window = window_size * window_sizepatch_total = bs * num_patch * patch_depthpatch_embedding = torch.arange(patch_total).reshape((bs, num_patch, patch_depth)).to(torch.float32)print(f"patch_embedding.shape=\n{patch_embedding.shape}")print(f"patch_embedding=\n{patch_embedding}")patch_embedding = patch_embedding.transpose(-1, -2)patch = patch_embedding.reshape(bs, patch_depth, image_height, image_width)print(f"patch=\n{patch}")window = F.unfold(patch, kernel_size=(window_size, window_size), stride=(window_size, window_size)).transpose(-1,-2)print(f"window.shape=\n{window.shape}")print(f"window=\n{window}")bs, num_window, patch_depth_times_num_patch_in_window = window.shapewindow = window.reshape(bs*num_window,patch_depth,num_patch_in_window).transpose(-1,-2)print(f"window.shape=\n{window.shape}")print(f"window=\n{window}")
- 结果:
patch_embedding.shape=
torch.Size([2, 16, 4])
patch_embedding=
tensor([[[ 0., 1., 2., 3.],[ 4., 5., 6., 7.],[ 8., 9., 10., 11.],[ 12., 13., 14., 15.],[ 16., 17., 18., 19.],[ 20., 21., 22., 23.],[ 24., 25., 26., 27.],[ 28., 29., 30., 31.],[ 32., 33., 34., 35.],[ 36., 37., 38., 39.],[ 40., 41., 42., 43.],[ 44., 45., 46., 47.],[ 48., 49., 50., 51.],[ 52., 53., 54., 55.],[ 56., 57., 58., 59.],[ 60., 61., 62., 63.]],[[ 64., 65., 66., 67.],[ 68., 69., 70., 71.],[ 72., 73., 74., 75.],[ 76., 77., 78., 79.],[ 80., 81., 82., 83.],[ 84., 85., 86., 87.],[ 88., 89., 90., 91.],[ 92., 93., 94., 95.],[ 96., 97., 98., 99.],[100., 101., 102., 103.],[104., 105., 106., 107.],[108., 109., 110., 111.],[112., 113., 114., 115.],[116., 117., 118., 119.],[120., 121., 122., 123.],[124., 125., 126., 127.]]])
patch=
tensor([[[[ 0., 4., 8., 12.],[ 16., 20., 24., 28.],[ 32., 36., 40., 44.],[ 48., 52., 56., 60.]],[[ 1., 5., 9., 13.],[ 17., 21., 25., 29.],[ 33., 37., 41., 45.],[ 49., 53., 57., 61.]],[[ 2., 6., 10., 14.],[ 18., 22., 26., 30.],[ 34., 38., 42., 46.],[ 50., 54., 58., 62.]],[[ 3., 7., 11., 15.],[ 19., 23., 27., 31.],[ 35., 39., 43., 47.],[ 51., 55., 59., 63.]]],[[[ 64., 68., 72., 76.],[ 80., 84., 88., 92.],[ 96., 100., 104., 108.],[112., 116., 120., 124.]],[[ 65., 69., 73., 77.],[ 81., 85., 89., 93.],[ 97., 101., 105., 109.],[113., 117., 121., 125.]],[[ 66., 70., 74., 78.],[ 82., 86., 90., 94.],[ 98., 102., 106., 110.],[114., 118., 122., 126.]],[[ 67., 71., 75., 79.],[ 83., 87., 91., 95.],[ 99., 103., 107., 111.],[115., 119., 123., 127.]]]])
window.shape=
torch.Size([2, 4, 16])
window=
tensor([[[ 0., 4., 16., 20., 1., 5., 17., 21., 2., 6., 18.,22., 3., 7., 19., 23.],[ 8., 12., 24., 28., 9., 13., 25., 29., 10., 14., 26.,30., 11., 15., 27., 31.],[ 32., 36., 48., 52., 33., 37., 49., 53., 34., 38., 50.,54., 35., 39., 51., 55.],[ 40., 44., 56., 60., 41., 45., 57., 61., 42., 46., 58.,62., 43., 47., 59., 63.]],[[ 64., 68., 80., 84., 65., 69., 81., 85., 66., 70., 82.,86., 67., 71., 83., 87.],[ 72., 76., 88., 92., 73., 77., 89., 93., 74., 78., 90.,94., 75., 79., 91., 95.],[ 96., 100., 112., 116., 97., 101., 113., 117., 98., 102., 114.,118., 99., 103., 115., 119.],[104., 108., 120., 124., 105., 109., 121., 125., 106., 110., 122.,126., 107., 111., 123., 127.]]])
window.shape=
torch.Size([8, 4, 4])
window=
tensor([[[ 0., 1., 2., 3.],[ 4., 5., 6., 7.],[ 16., 17., 18., 19.],[ 20., 21., 22., 23.]],[[ 8., 9., 10., 11.],[ 12., 13., 14., 15.],[ 24., 25., 26., 27.],[ 28., 29., 30., 31.]],[[ 32., 33., 34., 35.],[ 36., 37., 38., 39.],[ 48., 49., 50., 51.],[ 52., 53., 54., 55.]],[[ 40., 41., 42., 43.],[ 44., 45., 46., 47.],[ 56., 57., 58., 59.],[ 60., 61., 62., 63.]],[[ 64., 65., 66., 67.],[ 68., 69., 70., 71.],[ 80., 81., 82., 83.],[ 84., 85., 86., 87.]],[[ 72., 73., 74., 75.],[ 76., 77., 78., 79.],[ 88., 89., 90., 91.],[ 92., 93., 94., 95.]],[[ 96., 97., 98., 99.],[100., 101., 102., 103.],[112., 113., 114., 115.],[116., 117., 118., 119.]],[[104., 105., 106., 107.],[108., 109., 110., 111.],[120., 121., 122., 123.],[124., 125., 126., 127.]]])