1. 生成类别矩阵如下

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)


def kron_category_matrix(side, block_size=2):
    """Build a block "category" matrix by Kronecker expansion.

    Args:
        side: the label grid is ``side`` x ``side`` and holds the labels
            1..side**2 in row-major order.
        block_size: every label is expanded into a ``block_size`` x
            ``block_size`` block.

    Returns:
        A tuple ``(labels, ones_block, category)``: the integer label grid,
        the all-ones expansion block, and their Kronecker product.
    """
    labels = torch.arange(side * side).reshape(side, side) + 1
    ones_block = torch.ones((block_size, block_size))
    category = torch.kron(input=labels, other=ones_block)
    return labels, ones_block, category


if __name__ == "__main__":
    # 2x2 label grid, each label expanded into a 2x2 block -> 4x4 matrix.
    a_matrix, b_matrix, c_matrix = kron_category_matrix(2)
    print(f"a_matrix=\n{a_matrix}")
    print(f"b_matrix=\n{b_matrix}")
    print(f"c_matrix=\n{c_matrix}")

    # 3x3 label grid, each label expanded into a 2x2 block -> 6x6 matrix.
    d_matrix, e_matrix, f_matrix = kron_category_matrix(3)
    print(f"d_matrix=\n{d_matrix}")
    print(f"e_matrix=\n{e_matrix}")
    print(f"f_matrix=\n{f_matrix}")

    # Dropping one row/column from every edge moves the window grid by one
    # cell, so each 2x2 window of the 4x4 result mixes several labels.
    g_matrix = f_matrix[1:-1, 1:-1]
    print(f"g_matrix=\n{g_matrix}")
a_matrix=
tensor([[1, 2],[3, 4]])
b_matrix=
tensor([[1., 1.],[1., 1.]])
c_matrix=
tensor([[1., 1., 2., 2.],[1., 1., 2., 2.],[3., 3., 4., 4.],[3., 3., 4., 4.]])
d_matrix=
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
e_matrix=
tensor([[1., 1.],[1., 1.]])
f_matrix=
tensor([[1., 1., 2., 2., 3., 3.],[1., 1., 2., 2., 3., 3.],[4., 4., 5., 5., 6., 6.],[4., 4., 5., 5., 6., 6.],[7., 7., 8., 8., 9., 9.],[7., 7., 8., 8., 9., 9.]])
g_matrix=
tensor([[1., 2., 2., 3.],[4., 5., 5., 6.],[4., 5., 5., 6.],[7., 8., 8., 9.]])
3. 循环移动矩阵
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)


def _category_matrix(num_patch, size):
    """Build the window "category" matrix.

    The ``num_patch`` windows are labelled 1..num_patch in row-major order on a
    square grid, and every label is expanded into a ``size`` x ``size`` block
    by a Kronecker product with a matrix of ones.

    Raises:
        ValueError: if ``num_patch`` is not a perfect square.
    """
    grid_side = math.isqrt(num_patch)
    if grid_side * grid_side != num_patch:
        raise ValueError(f"num_patch must be a perfect square, got {num_patch}")
    labels = torch.arange(num_patch).reshape(grid_side, grid_side) + 1
    return torch.kron(input=labels, other=torch.ones(size, size))


def _trim_border(matrix, shift):
    """Return the 2-D ``matrix`` with ``shift`` rows/columns removed from every edge."""
    rows, cols = matrix.shape
    return matrix[shift:rows - shift, shift:cols - shift]


class _WindowMatrixBase(object):
    """Constructor state shared by all window-matrix classes below."""

    def __init__(self, num_patch, size):
        self.num_patch = num_patch
        self.size = size
        # NOTE(review): this placeholder is (num_patch, size*size), which only
        # coincides with the shape ``result`` returns for some inputs
        # (e.g. num_patch=4, size=2). Confirm what width/height describe.
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))


class WindowMatrix(_WindowMatrixBase):
    """Category matrix of a regular, unshifted window partition."""

    def __init__(self, num_patch=4, size=2):
        super().__init__(num_patch, size)

    @property
    def result(self):
        """Recompute the category matrix, store it in ``_result`` and return it."""
        self._result = _category_matrix(self.num_patch, self.size)
        return self._result


class _ShiftedWindowMatrixBase(_WindowMatrixBase):
    """Shared logic for the shifted-window variants.

    Args:
        num_patch: number of windows; must be a perfect square.
        size: side length of each window block.
        shift: number of rows/columns trimmed from every edge, which is also
            the cyclic roll distance. Defaults to 1.

    Raises:
        ValueError: if ``shift`` is negative.
    """

    def __init__(self, num_patch=9, size=2, shift=1):
        super().__init__(num_patch, size)
        if shift < 0:
            raise ValueError(f"shift must be non-negative, got {shift}")
        self.shift = shift

    def _shifted_matrix(self):
        # The shifted partition is the category matrix with a ``shift``-wide
        # border removed on every side.
        return _trim_border(_category_matrix(self.num_patch, self.size), self.shift)

    def _roll(self, matrix, direction):
        # direction=-1 applies the cyclic shift on both axes; +1 undoes it.
        offset = direction * self.shift
        return torch.roll(input=matrix, shifts=(offset, offset), dims=(-1, -2))


class ShiftedWindowMatrix(_ShiftedWindowMatrixBase):
    """Category matrix as seen through a shifted window partition."""

    @property
    def result(self):
        """Recompute the shifted matrix, store it in ``_result`` and return it."""
        self._result = self._shifted_matrix()
        return self._result


class RollShiftedWindowMatrix(_ShiftedWindowMatrixBase):
    """Shifted category matrix after a cyclic roll of ``-shift`` on both axes."""

    @property
    def result(self):
        """Recompute the rolled matrix, store it in ``_result`` and return it."""
        self._result = self._roll(self._shifted_matrix(), direction=-1)
        return self._result


class BackRollShiftedWindowMatrix(_ShiftedWindowMatrixBase):
    """Cyclic roll followed by the inverse roll, recovering the shifted matrix."""

    @property
    def result(self):
        """Roll, print the intermediate, roll back, store in ``_result`` and return."""
        roll_result = self._roll(self._shifted_matrix(), direction=-1)
        # Printed deliberately so the walkthrough shows the intermediate state.
        print(f"roll_result=\n{roll_result}")
        self._result = self._roll(roll_result, direction=1)
        return self._result


if __name__ == "__main__":
    window_matrix = WindowMatrix()
    window_matrix_result = window_matrix.result
    print(f"my_window_matrix_result=\n{window_matrix_result}")

    shifted_window_matrix = ShiftedWindowMatrix()
    shifted_window_matrix_result = shifted_window_matrix.result
    print(f"shifed_window_matrix_result=\n{shifted_window_matrix_result}")

    roll_shifted_window_matrix = RollShiftedWindowMatrix()
    roll_shifted_window_matrix_result = roll_shifted_window_matrix.result
    print(f"roll_shifed_window_matrix_result=\n{roll_shifted_window_matrix_result}")

    back_roll_shifted_window_matrix = BackRollShiftedWindowMatrix()
    back_roll_shifted_window_matrix_result = back_roll_shifted_window_matrix.result
    print(f"back_roll_shifed_window_matrix_result=\n{back_roll_shifted_window_matrix_result}")
my_window_matrix_result=
tensor([[1., 1., 2., 2.],[1., 1., 2., 2.],[3., 3., 4., 4.],[3., 3., 4., 4.]])
shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],[4., 5., 5., 6.],[4., 5., 5., 6.],[7., 8., 8., 9.]])
roll_shifed_window_matrix_result=
tensor([[5., 5., 6., 4.],[5., 5., 6., 4.],[8., 8., 9., 7.],[2., 2., 3., 1.]])
roll_result=
tensor([[5., 5., 6., 4.],[5., 5., 6., 4.],[8., 8., 9., 7.],[2., 2., 3., 1.]])
back_roll_shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],[4., 5., 5., 6.],[4., 5., 5., 6.],[7., 8., 8., 9.]])