kron积计算mask类别矩阵

devtools/2025/2/19 17:40:53/

文章目录

1. 生成类别矩阵如下

在这里插入图片描述

pytorch__3">2. pytorch 代码

python">import torch
import torch.nn as nn
import torch.nn.functional as Ftorch.set_printoptions(precision=3, sci_mode=False)if __name__ == "__main__":run_code = 0a_matrix = torch.arange(4).reshape(2, 2) + 1b_matrix = torch.ones((2, 2))print(f"a_matrix=\n{a_matrix}")print(f"b_matrix=\n{b_matrix}")c_matrix = torch.kron(input=a_matrix, other=b_matrix)print(f"c_matrix=\n{c_matrix}")d_matrix = torch.arange(9).reshape(3, 3) + 1e_matrix = torch.ones((2, 2))f_matrix = torch.kron(input=d_matrix, other=e_matrix)print(f"d_matrix=\n{d_matrix}")print(f"e_matrix=\n{e_matrix}")print(f"f_matrix=\n{f_matrix}")g_matrix = f_matrix[1:-1, 1:-1]print(f"g_matrix=\n{g_matrix}")
  • 结果:
python">a_matrix=
tensor([[1, 2],[3, 4]])
b_matrix=
tensor([[1., 1.],[1., 1.]])
c_matrix=
tensor([[1., 1., 2., 2.],[1., 1., 2., 2.],[3., 3., 4., 4.],[3., 3., 4., 4.]])
d_matrix=
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
e_matrix=
tensor([[1., 1.],[1., 1.]])
f_matrix=
tensor([[1., 1., 2., 2., 3., 3.],[1., 1., 2., 2., 3., 3.],[4., 4., 5., 5., 6., 6.],[4., 4., 5., 5., 6., 6.],[7., 7., 8., 8., 9., 9.],[7., 7., 8., 8., 9., 9.]])
g_matrix=
tensor([[1., 2., 2., 3.],[4., 5., 5., 6.],[4., 5., 5., 6.],[7., 8., 8., 9.]])

3. 循环移动矩阵

python">import torch
import torch.nn as nn
import torch.nn.functional as F
import mathtorch.set_printoptions(precision=3, sci_mode=False)class WindowMatrix(object):def __init__(self, num_patch=4, size=2):self.num_patch = num_patchself.size = sizeself.width = self.num_patchself.height = self.size * self.sizeself._result = torch.zeros((self.width, self.height))@propertydef result(self):a_size = int(math.sqrt(self.num_patch))a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1b_matrix = torch.ones(self.size, self.size)self._result = torch.kron(input=a_matrix, other=b_matrix)return self._resultclass ShiftedWindowMatrix(object):def __init__(self, num_patch=9, size=2):self.num_patch = num_patchself.size = sizeself.width = self.num_patchself.height = self.size * self.sizeself._result = torch.zeros((self.width, self.height))@propertydef result(self):a_size = int(math.sqrt(self.num_patch))a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1b_matrix = torch.ones(self.size, self.size)my_result = torch.kron(input=a_matrix, other=b_matrix)self._result = my_result[1:-1, 1:-1]return self._resultclass RollShiftedWindowMatrix(object):def __init__(self, num_patch=9, size=2):self.num_patch = num_patchself.size = sizeself.width = self.num_patchself.height = self.size * self.sizeself._result = torch.zeros((self.width, self.height))@propertydef result(self):a_size = int(math.sqrt(self.num_patch))a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1b_matrix = torch.ones(self.size, self.size)my_result = torch.kron(input=a_matrix, other=b_matrix)my_result = my_result[1:-1, 1:-1]roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))self._result = roll_resultreturn self._resultclass BackRollShiftedWindowMatrix(object):def __init__(self, num_patch=9, size=2):self.num_patch = num_patchself.size = sizeself.width = self.num_patchself.height = self.size * self.sizeself._result = torch.zeros((self.width, self.height))@propertydef result(self):a_size = int(math.sqrt(self.num_patch))a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1b_matrix = torch.ones(self.size, self.size)my_result = torch.kron(input=a_matrix, other=b_matrix)my_result = my_result[1:-1, 1:-1]roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))print(f"roll_result=\n{roll_result}")roll_result = torch.roll(input=roll_result, shifts=(1, 1), dims=(-1, -2))self._result = roll_resultreturn self._resultif __name__ == "__main__":run_code = 0my_window_matrix = WindowMatrix()my_window_matrix_result = my_window_matrix.resultprint(f"my_window_matrix_result=\n{my_window_matrix_result}")shifted_window_matrix = ShiftedWindowMatrix()shifed_window_matrix_result = shifted_window_matrix.resultprint(f"shifed_window_matrix_result=\n{shifed_window_matrix_result}")roll_shifted_window_matrix = RollShiftedWindowMatrix()roll_shifed_window_matrix_result = roll_shifted_window_matrix.resultprint(f"roll_shifed_window_matrix_result=\n{roll_shifed_window_matrix_result}")Back_roll_shifted_window_matrix = BackRollShiftedWindowMatrix()back_roll_shifed_window_matrix_result = Back_roll_shifted_window_matrix.resultprint(f"back_roll_shifed_window_matrix_result=\n{back_roll_shifed_window_matrix_result}")
  • 结果:
python">my_window_matrix_result=
tensor([[1., 1., 2., 2.],[1., 1., 2., 2.],[3., 3., 4., 4.],[3., 3., 4., 4.]])
shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],[4., 5., 5., 6.],[4., 5., 5., 6.],[7., 8., 8., 9.]])
roll_shifed_window_matrix_result=
tensor([[5., 5., 6., 4.],[5., 5., 6., 4.],[8., 8., 9., 7.],[2., 2., 3., 1.]])
roll_result=
tensor([[5., 5., 6., 4.],[5., 5., 6., 4.],[8., 8., 9., 7.],[2., 2., 3., 1.]])
back_roll_shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],[4., 5., 5., 6.],[4., 5., 5., 6.],[7., 8., 8., 9.]])

http://www.ppmy.cn/devtools/159263.html

相关文章

rustdesk远程桌面自建服务器

首先,我这里用到的是阿里云服务器 centos7版本,win版客户端。 准备工作 centos7 服务器端文件: https://github.com/rustdesk/rustdesk-server/releases/download/1.1.11-1/rustdesk-server-linux-amd64.zip win版客户端安装包&#xff1…

P10452 货仓选址

链接:P10452 货仓选址 - 洛谷 题目描述 在一条数轴上有 N 家商店,它们的坐标分别为 A1​∼AN​。 现在需要在数轴上建立一家货仓,每天清晨,从货仓到每家商店都要运送一车商品。 为了提高效率,求把货仓建在何处&…

游戏引擎学习第101天

回顾当前情况 昨天的进度基本上完成了所有内容,但我们还没有进行调试。虽然我们在运行时做的事情大致上是对的,但还是存在一些可能或者确定的bug。正如昨天最后提到的,既然现在时间晚了,就不太适合开始调试,所以今天我…

C# windowForms 的DataGridView控件的使用

C# Windows Forms DataGridView 控件使用详解 DataGridView 是 Windows Forms 中用于显示和编辑表格数据的核心控件。它支持高度自定义的列类型、数据绑定、事件处理和丰富的样式配置。以下是其详细使用方法。 目录 基础使用 数据绑定 列类型与自定义

PostgreSQL技术内幕25:时序数据库插件TimescaleDB

文章目录 0.简介1.基础知识1.1 背景1.2 概念1.3 特点 2.TimescaleDB2.1 安装使用2.1 文件结构2.2 原理2.2.1 整体结构2.2.2 超表2.2.3 自动分区2.2.4 数据写入与查询优化2.2.5 数据保留策略2.2.6 更多特性 0.简介 现今时序数据库的应用场景十分广泛,其通过保留时间…

[矩形绘制]

矩形绘制 真题目录: 点击去查看 E 卷 200分题型 题目描述 实现一个简单的绘图模块,绘图模块仅支持矩形的绘制和擦除 当新绘制的矩形与之前的图形重叠时,对图形取并集当新擦除的矩形与之前的图形重叠时,对图形取差集给定一系列矩形的绘制和擦除操作,计算最终图形的面积。 …

LDR6500 PD芯片:智能充电与数据传输

LDR6500 PD芯片:引领智能充电与数据传输新时代 随着科技的飞速发展,电子设备的充电与数据传输需求愈发迫切。为满足这一需求,PD(Power Delivery)芯片应运而生,其中LDR6500以其卓越的性能和广泛的应用前景&…

shell命令脚本(2)——条件语句

个人博客站—运维鹿: http://www.kervin24.top CSDN博客—做个超努力的小奚: https://blog.csdn.net/qq_52914969?typeblog1、条件测试(上) 1.1.1、字符串比较 基本语法:判断成功为0,不成功为1 是否为空[ -z “字符串” ]是否…