文章目录
- 1. 相对位置矩阵2d
- 2. kron运算
1. 相对位置矩阵2d
在swin-transformer中,我们会计算每个patch之间的相对位置,那么我们看到有一连串的拉伸和相减,直接贴代码:
python">import torch
import torch.nn as nntorch.set_printoptions(precision=3, sci_mode=False,threshold=torch.inf)if __name__ == "__main__":run_code = 2x_len = 5y_len = 5x_tensor = torch.arange(x_len)y_tensor = torch.arange(y_len)x_meshgrid, y_meshgrid = torch.meshgrid(x_tensor, y_tensor)print(f"x_tensor=\n{x_tensor}")print(f"y_tensor=\n{y_tensor}")print(f"x_meshgrid=\n{x_meshgrid}")print(f"x_meshgrid.shape=\n{x_meshgrid.shape}")print(f"y_meshgrid.shape=\n{y_meshgrid.shape}")print(f"y_meshgrid=\n{y_meshgrid}")stack_meshgrid = torch.stack(torch.meshgrid(x_tensor, y_tensor))print(f"stack_meshgrid.shape=\n{stack_meshgrid.shape}")print(f"stack_meshgrid=\n{stack_meshgrid}")stack_meshgrid_flatten = torch.flatten(stack_meshgrid, 1)print(f"stack_meshgrid_flatten.shape=\n{stack_meshgrid_flatten.shape}")print(f"stack_meshgrid_flatten=\n{stack_meshgrid_flatten}")stack_meshgrid_flatten_1 = stack_meshgrid_flatten[:, None, :]stack_meshgrid_flatten_2 = stack_meshgrid_flatten[:, :, None]relative_coords_bias = stack_meshgrid_flatten_2 - stack_meshgrid_flatten_1print(f"stack_meshgrid_flatten_1=\n{stack_meshgrid_flatten_1}")print(f"stack_meshgrid_flatten_2=\n{stack_meshgrid_flatten_2}")print(f"relative_coords_bias=\n{relative_coords_bias}")relative_coords_bias[0, :, :] += x_lenrelative_coords_bias[1, :, :] += y_lenprint(f"relative_coords_bias=\n{relative_coords_bias}")
- result:
python">x_tensor=
tensor([0, 1, 2, 3, 4])
y_tensor=
tensor([0, 1, 2, 3, 4])
x_meshgrid=
tensor([[0, 0, 0, 0, 0],[1, 1, 1, 1, 1],[2, 2, 2, 2, 2],[3, 3, 3, 3, 3],[4, 4, 4, 4, 4]])
x_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid=
tensor([[0, 1, 2, 3, 4],[0, 1, 2, 3, 4],[0, 1, 2, 3, 4],[0, 1, 2, 3, 4],[0, 1, 2, 3, 4]])
stack_meshgrid.shape=
torch.Size([2, 5, 5])
stack_meshgrid=
tensor([[[0, 0, 0, 0, 0],[1, 1, 1, 1, 1],[2, 2, 2, 2, 2],[3, 3, 3, 3, 3],[4, 4, 4, 4, 4]],[[0, 1, 2, 3, 4],[0, 1, 2, 3, 4],[0, 1, 2, 3, 4],[0, 1, 2, 3, 4],[0, 1, 2, 3, 4]]])
stack_meshgrid_flatten.shape=
torch.Size([2, 25])
stack_meshgrid_flatten=
tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,4],[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,4]])
stack_meshgrid_flatten_1=
tensor([[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4,4, 4]],[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2,3, 4]]])
stack_meshgrid_flatten_2=
tensor([[[0],[0],[0],[0],[0],[1],[1],[1],[1],[1],[2],[2],[2],[2],[2],[3],[3],[3],[3],[3],[4],[4],[4],[4],[4]],[[0],[1],[2],[3],[4],[0],[1],[2],[3],[4],[0],[1],[2],[3],[4],[0],[1],[2],[3],[4],[0],[1],[2],[3],[4]]])
relative_coords_bias=
tensor([[[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,-3, -3, -3, -4, -4, -4, -4, -4],[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,-3, -3, -3, -4, -4, -4, -4, -4],[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,-3, -3, -3, -4, -4, -4, -4, -4],[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,-3, -3, -3, -4, -4, -4, -4, -4],[ 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,-3, -3, -3, -4, -4, -4, -4, -4],[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,-2, -2, -2, -3, -3, -3, -3, -3],[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,-2, -2, -2, -3, -3, -3, -3, -3],[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,-2, -2, -2, -3, -3, -3, -3, -3],[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,-2, -2, -2, -3, -3, -3, -3, -3],[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -2, -2,-2, -2, -2, -3, -3, -3, -3, -3],[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,-1, -1, -1, -2, -2, -2, -2, -2],[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,-1, -1, -1, -2, -2, -2, -2, -2],[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,-1, -1, -1, -2, -2, -2, -2, -2],[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,-1, -1, -1, -2, -2, -2, -2, -2],[ 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, -1, -1,-1, -1, -1, -2, -2, -2, -2, -2],[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,0, 0, 0, -1, -1, -1, -1, -1],[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,0, 0, 0, -1, -1, -1, -1, -1],[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,0, 0, 0, -1, -1, -1, -1, -1],[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,0, 0, 0, -1, -1, -1, -1, -1],[ 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0,0, 0, 0, -1, -1, -1, -1, -1],[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,1, 1, 1, 0, 0, 0, 0, 0],[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,1, 1, 1, 0, 0, 0, 0, 0],[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,1, 1, 1, 0, 0, 0, 0, 0],[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,1, 1, 1, 0, 0, 0, 0, 0],[ 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1,1, 1, 1, 0, 0, 0, 0, 0]],[[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,-2, -3, -4, 0, -1, -2, -3, -4],[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,-1, -2, -3, 1, 0, -1, -2, -3],[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,0, -1, -2, 2, 1, 0, -1, -2],[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,1, 0, -1, 3, 2, 1, 0, -1],[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,2, 1, 0, 4, 3, 2, 1, 0],[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,-2, -3, -4, 0, -1, -2, -3, -4],[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,-1, -2, -3, 1, 0, -1, -2, -3],[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,0, -1, -2, 2, 1, 0, -1, -2],[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,1, 0, -1, 3, 2, 1, 0, -1],[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,2, 1, 0, 4, 3, 2, 1, 0],[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,-2, -3, -4, 0, -1, -2, -3, -4],[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,-1, -2, -3, 1, 0, -1, -2, -3],[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,0, -1, -2, 2, 1, 0, -1, -2],[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,1, 0, -1, 3, 2, 1, 0, -1],[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,2, 1, 0, 4, 3, 2, 1, 0],[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,-2, -3, -4, 0, -1, -2, -3, -4],[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,-1, -2, -3, 1, 0, -1, -2, -3],[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,0, -1, -2, 2, 1, 0, -1, -2],[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,1, 0, -1, 3, 2, 1, 0, -1],[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,2, 1, 0, 4, 3, 2, 1, 0],[ 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1, -2, -3, -4, 0, -1,-2, -3, -4, 0, -1, -2, -3, -4],[ 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0, -1, -2, -3, 1, 0,-1, -2, -3, 1, 0, -1, -2, -3],[ 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1, 0, -1, -2, 2, 1,0, -1, -2, 2, 1, 0, -1, -2],[ 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2, 1, 0, -1, 3, 2,1, 0, -1, 3, 2, 1, 0, -1],[ 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0, 4, 3,2, 1, 0, 4, 3, 2, 1, 0]]])
relative_coords_bias=
tensor([[[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,1, 1],[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,1, 1],[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,1, 1],[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,1, 1],[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,1, 1],[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,2, 2],[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,2, 2],[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,2, 2],[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,2, 2],[6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,2, 2],[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,3, 3],[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,3, 3],[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,3, 3],[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,3, 3],[7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,3, 3],[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,4, 4],[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,4, 4],[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,4, 4],[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,4, 4],[8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,4, 4],[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,5, 5],[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,5, 5],[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,5, 5],[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,5, 5],[9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,5, 5]],[[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,2, 1],[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,3, 2],[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,4, 3],[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,5, 4],[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,6, 5],[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,2, 1],[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,3, 2],[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,4, 3],[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,5, 4],[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,6, 5],[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,2, 1],[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,3, 2],[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,4, 3],[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,5, 4],[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,6, 5],[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,2, 1],[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,3, 2],[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,4, 3],[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,5, 4],[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,6, 5],[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,2, 1],[6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,3, 2],[7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,4, 3],[8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,5, 4],[9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,6, 5]]])
2. kron运算
在结果中,我们发现很多重复的值,这就让我联想到kron运算。
python">import torch
import torch.nn as nntorch.set_printoptions(precision=3, sci_mode=False)if __name__ == '__main__':run_code = 0height = 5width = 5a_vector = torch.arange(width).to(torch.float).reshape(-1, 1)a_ones = torch.ones(1, width)a_matrix = a_vector @ a_onesprint(f"a_matrix=\n{a_matrix}")b_matrix = a_matrix - a_matrix.Tprint(f"b_matrix=\n{b_matrix}")b_matrix_ones = torch.ones_like(b_matrix)ab_kron = torch.kron(b_matrix,b_matrix_ones)print(f"ab_kron=\n{ab_kron}")final_ab = ab_kron+5print(f"final_ab=\n{final_ab}")
- result:
python">a_matrix=
tensor([[0., 0., 0., 0., 0.],[1., 1., 1., 1., 1.],[2., 2., 2., 2., 2.],[3., 3., 3., 3., 3.],[4., 4., 4., 4., 4.]])
b_matrix=
tensor([[ 0., -1., -2., -3., -4.],[ 1., 0., -1., -2., -3.],[ 2., 1., 0., -1., -2.],[ 3., 2., 1., 0., -1.],[ 4., 3., 2., 1., 0.]])
ab_kron=
tensor([[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],[ 0., 0., 0., 0., 0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,-2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],[ 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., -1., -1., -1., -1.,-1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],[ 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 0., 0., 0., 0.,0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],[ 3., 3., 3., 3., 3., 2., 2., 2., 2., 2., 1., 1., 1., 1.,1., 0., 0., 0., 0., 0., -1., -1., -1., -1., -1.],[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],[ 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2., 2.,2., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]])
final_ab=
tensor([[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,2., 2., 1., 1., 1., 1., 1.],[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,2., 2., 1., 1., 1., 1., 1.],[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,2., 2., 1., 1., 1., 1., 1.],[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,2., 2., 1., 1., 1., 1., 1.],[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,2., 2., 1., 1., 1., 1., 1.],[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,3., 3., 2., 2., 2., 2., 2.],[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,3., 3., 2., 2., 2., 2., 2.],[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,3., 3., 2., 2., 2., 2., 2.],[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,3., 3., 2., 2., 2., 2., 2.],[6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,3., 3., 2., 2., 2., 2., 2.],[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,4., 4., 3., 3., 3., 3., 3.],[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,4., 4., 3., 3., 3., 3., 3.],[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,4., 4., 3., 3., 3., 3., 3.],[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,4., 4., 3., 3., 3., 3., 3.],[7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,4., 4., 3., 3., 3., 3., 3.],[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,5., 5., 4., 4., 4., 4., 4.],[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,5., 5., 4., 4., 4., 4., 4.],[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,5., 5., 4., 4., 4., 4., 4.],[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,5., 5., 4., 4., 4., 4., 4.],[8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,5., 5., 4., 4., 4., 4., 4.],[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,6., 6., 5., 5., 5., 5., 5.],[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,6., 6., 5., 5., 5., 5., 5.],[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,6., 6., 5., 5., 5., 5., 5.],[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,6., 6., 5., 5., 5., 5., 5.],[9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,6., 6., 5., 5., 5., 5., 5.]])