window patch按块分割矩阵

ops/2025/2/13 17:55:32/

文章目录

  • 1. excel 示意
  • 2. pytorch代码
  • 3. window mhsa

1. excel 示意

将一个三维矩阵按照window的大小进行拆分成多块2x2窗口矩阵,具体如下图所示
在这里插入图片描述

pytorch_4">2. pytorch代码

import torch
import torch.nn as nn
import torch.nn.functional as Ftorch.set_printoptions(precision=3, sci_mode=False)if __name__ == "__main__":run_code = 0batch_size = 2seq_len = 4model_dim = 6patch_total = batch_size * seq_len * model_dimpatch = torch.arange(patch_total).reshape((batch_size, seq_len, model_dim)).to(torch.float32)print(f"patch.shape=\n{patch.shape}")print(f"patch=\n{patch}")patch_unfold = F.unfold(input=patch, kernel_size=(2, 2), stride=(2, 2))print(f"patch_unfold.shape=\n{patch_unfold.shape}")print(f"patch_unfold=\n{patch_unfold}")#   patch_unfold = patch_unfold.transpose(-1, -2)print(f"patch_unfold=\n{patch_unfold}")patch_nums = patch_unfold.reshape(batch_size, 4, 6)print(f"patch_nums=\n{patch_nums}")patch_nums_new = patch_nums.transpose(-1, -2)print(f"patch_nums_new.shape=\n{patch_nums_new.shape}")print(f"patch_nums_new=\n{patch_nums_new}")patch_nums_final = patch_nums_new.reshape(12, 2, 2)print(f"patch_nums_final.shape=\n{patch_nums_final.shape}")print(f"patch_nums_final=\n{patch_nums_final}")
  • 结果:
patch.shape=
torch.Size([2, 4, 6])
patch=
tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.],[ 6.,  7.,  8.,  9., 10., 11.],[12., 13., 14., 15., 16., 17.],[18., 19., 20., 21., 22., 23.]],[[24., 25., 26., 27., 28., 29.],[30., 31., 32., 33., 34., 35.],[36., 37., 38., 39., 40., 41.],[42., 43., 44., 45., 46., 47.]]])
patch_unfold.shape=
torch.Size([8, 6])
patch_unfold=
tensor([[ 0.,  2.,  4., 12., 14., 16.],[ 1.,  3.,  5., 13., 15., 17.],[ 6.,  8., 10., 18., 20., 22.],[ 7.,  9., 11., 19., 21., 23.],[24., 26., 28., 36., 38., 40.],[25., 27., 29., 37., 39., 41.],[30., 32., 34., 42., 44., 46.],[31., 33., 35., 43., 45., 47.]])
patch_unfold=
tensor([[ 0.,  2.,  4., 12., 14., 16.],[ 1.,  3.,  5., 13., 15., 17.],[ 6.,  8., 10., 18., 20., 22.],[ 7.,  9., 11., 19., 21., 23.],[24., 26., 28., 36., 38., 40.],[25., 27., 29., 37., 39., 41.],[30., 32., 34., 42., 44., 46.],[31., 33., 35., 43., 45., 47.]])
patch_nums=
tensor([[[ 0.,  2.,  4., 12., 14., 16.],[ 1.,  3.,  5., 13., 15., 17.],[ 6.,  8., 10., 18., 20., 22.],[ 7.,  9., 11., 19., 21., 23.]],[[24., 26., 28., 36., 38., 40.],[25., 27., 29., 37., 39., 41.],[30., 32., 34., 42., 44., 46.],[31., 33., 35., 43., 45., 47.]]])
patch_nums_new.shape=
torch.Size([2, 6, 4])
patch_nums_new=
tensor([[[ 0.,  1.,  6.,  7.],[ 2.,  3.,  8.,  9.],[ 4.,  5., 10., 11.],[12., 13., 18., 19.],[14., 15., 20., 21.],[16., 17., 22., 23.]],[[24., 25., 30., 31.],[26., 27., 32., 33.],[28., 29., 34., 35.],[36., 37., 42., 43.],[38., 39., 44., 45.],[40., 41., 46., 47.]]])
patch_nums_final.shape=
torch.Size([12, 2, 2])
patch_nums_final=
tensor([[[ 0.,  1.],[ 6.,  7.]],[[ 2.,  3.],[ 8.,  9.]],[[ 4.,  5.],[10., 11.]],[[12., 13.],[18., 19.]],[[14., 15.],[20., 21.]],[[16., 17.],[22., 23.]],[[24., 25.],[30., 31.]],[[26., 27.],[32., 33.]],[[28., 29.],[34., 35.]],[[36., 37.],[42., 43.]],[[38., 39.],[44., 45.]],[[40., 41.],[46., 47.]]])

3. window mhsa

import torch
import torch.nn as nn
import torch.nn.functional as Ftorch.set_printoptions(precision=3, sci_mode=False)if __name__ == "__main__":run_code = 0bs = 2num_patch = 16patch_depth = 4window_size = 2image_height = image_width = 4num_patch_in_window = window_size * window_sizepatch_total = bs * num_patch * patch_depthpatch_embedding = torch.arange(patch_total).reshape((bs, num_patch, patch_depth)).to(torch.float32)print(f"patch_embedding.shape=\n{patch_embedding.shape}")print(f"patch_embedding=\n{patch_embedding}")patch_embedding = patch_embedding.transpose(-1, -2)patch = patch_embedding.reshape(bs, patch_depth, image_height, image_width)print(f"patch=\n{patch}")window = F.unfold(patch, kernel_size=(window_size, window_size), stride=(window_size, window_size)).transpose(-1,-2)print(f"window.shape=\n{window.shape}")print(f"window=\n{window}")bs, num_window, patch_depth_times_num_patch_in_window = window.shapewindow = window.reshape(bs*num_window,patch_depth,num_patch_in_window).transpose(-1,-2)print(f"window.shape=\n{window.shape}")print(f"window=\n{window}")
  • 结果:
patch_embedding.shape=
torch.Size([2, 16, 4])
patch_embedding=
tensor([[[  0.,   1.,   2.,   3.],[  4.,   5.,   6.,   7.],[  8.,   9.,  10.,  11.],[ 12.,  13.,  14.,  15.],[ 16.,  17.,  18.,  19.],[ 20.,  21.,  22.,  23.],[ 24.,  25.,  26.,  27.],[ 28.,  29.,  30.,  31.],[ 32.,  33.,  34.,  35.],[ 36.,  37.,  38.,  39.],[ 40.,  41.,  42.,  43.],[ 44.,  45.,  46.,  47.],[ 48.,  49.,  50.,  51.],[ 52.,  53.,  54.,  55.],[ 56.,  57.,  58.,  59.],[ 60.,  61.,  62.,  63.]],[[ 64.,  65.,  66.,  67.],[ 68.,  69.,  70.,  71.],[ 72.,  73.,  74.,  75.],[ 76.,  77.,  78.,  79.],[ 80.,  81.,  82.,  83.],[ 84.,  85.,  86.,  87.],[ 88.,  89.,  90.,  91.],[ 92.,  93.,  94.,  95.],[ 96.,  97.,  98.,  99.],[100., 101., 102., 103.],[104., 105., 106., 107.],[108., 109., 110., 111.],[112., 113., 114., 115.],[116., 117., 118., 119.],[120., 121., 122., 123.],[124., 125., 126., 127.]]])
patch=
tensor([[[[  0.,   4.,   8.,  12.],[ 16.,  20.,  24.,  28.],[ 32.,  36.,  40.,  44.],[ 48.,  52.,  56.,  60.]],[[  1.,   5.,   9.,  13.],[ 17.,  21.,  25.,  29.],[ 33.,  37.,  41.,  45.],[ 49.,  53.,  57.,  61.]],[[  2.,   6.,  10.,  14.],[ 18.,  22.,  26.,  30.],[ 34.,  38.,  42.,  46.],[ 50.,  54.,  58.,  62.]],[[  3.,   7.,  11.,  15.],[ 19.,  23.,  27.,  31.],[ 35.,  39.,  43.,  47.],[ 51.,  55.,  59.,  63.]]],[[[ 64.,  68.,  72.,  76.],[ 80.,  84.,  88.,  92.],[ 96., 100., 104., 108.],[112., 116., 120., 124.]],[[ 65.,  69.,  73.,  77.],[ 81.,  85.,  89.,  93.],[ 97., 101., 105., 109.],[113., 117., 121., 125.]],[[ 66.,  70.,  74.,  78.],[ 82.,  86.,  90.,  94.],[ 98., 102., 106., 110.],[114., 118., 122., 126.]],[[ 67.,  71.,  75.,  79.],[ 83.,  87.,  91.,  95.],[ 99., 103., 107., 111.],[115., 119., 123., 127.]]]])
window.shape=
torch.Size([2, 4, 16])
window=
tensor([[[  0.,   4.,  16.,  20.,   1.,   5.,  17.,  21.,   2.,   6.,  18.,22.,   3.,   7.,  19.,  23.],[  8.,  12.,  24.,  28.,   9.,  13.,  25.,  29.,  10.,  14.,  26.,30.,  11.,  15.,  27.,  31.],[ 32.,  36.,  48.,  52.,  33.,  37.,  49.,  53.,  34.,  38.,  50.,54.,  35.,  39.,  51.,  55.],[ 40.,  44.,  56.,  60.,  41.,  45.,  57.,  61.,  42.,  46.,  58.,62.,  43.,  47.,  59.,  63.]],[[ 64.,  68.,  80.,  84.,  65.,  69.,  81.,  85.,  66.,  70.,  82.,86.,  67.,  71.,  83.,  87.],[ 72.,  76.,  88.,  92.,  73.,  77.,  89.,  93.,  74.,  78.,  90.,94.,  75.,  79.,  91.,  95.],[ 96., 100., 112., 116.,  97., 101., 113., 117.,  98., 102., 114.,118.,  99., 103., 115., 119.],[104., 108., 120., 124., 105., 109., 121., 125., 106., 110., 122.,126., 107., 111., 123., 127.]]])
window.shape=
torch.Size([8, 4, 4])
window=
tensor([[[  0.,   1.,   2.,   3.],[  4.,   5.,   6.,   7.],[ 16.,  17.,  18.,  19.],[ 20.,  21.,  22.,  23.]],[[  8.,   9.,  10.,  11.],[ 12.,  13.,  14.,  15.],[ 24.,  25.,  26.,  27.],[ 28.,  29.,  30.,  31.]],[[ 32.,  33.,  34.,  35.],[ 36.,  37.,  38.,  39.],[ 48.,  49.,  50.,  51.],[ 52.,  53.,  54.,  55.]],[[ 40.,  41.,  42.,  43.],[ 44.,  45.,  46.,  47.],[ 56.,  57.,  58.,  59.],[ 60.,  61.,  62.,  63.]],[[ 64.,  65.,  66.,  67.],[ 68.,  69.,  70.,  71.],[ 80.,  81.,  82.,  83.],[ 84.,  85.,  86.,  87.]],[[ 72.,  73.,  74.,  75.],[ 76.,  77.,  78.,  79.],[ 88.,  89.,  90.,  91.],[ 92.,  93.,  94.,  95.]],[[ 96.,  97.,  98.,  99.],[100., 101., 102., 103.],[112., 113., 114., 115.],[116., 117., 118., 119.]],[[104., 105., 106., 107.],[108., 109., 110., 111.],[120., 121., 122., 123.],[124., 125., 126., 127.]]])

http://www.ppmy.cn/ops/158095.html

相关文章

分布式 IO 模块:港口控制主柜的智能 “助手”

在繁忙的港口,每一个集装箱的装卸、每一艘货轮的停靠与离港,都离不开高效精准的控制系统。港口控制主柜作为整个港口作业的核心枢纽之一,其稳定运行至关重要。而明达技术自主研发推出的MR30分布式 IO 模块可作为从站,与 PLC&#…

redis之事件

文章目录 文件事件文件事件处理器的构成多路复用程序的实现事件的类型文件事件的处理器 时间事件实现时间事件应用实例:ServerCron函数 事件的调度与执行总结 Redis服务器是一个事件驱动程序,服务器需要处理以下两类事件: 文件事件&#xff0…

【GeeRPC】Day5:支持 HTTP 协议

Day5:支持 HTTP 协议 今天要完成的任务如下: 支持 HTTP 协议;基于 HTTP 实现一个简单的 Debug 页面,代码约 150 行; 支持 HTTP 协议需要什么? Web 开发中,我们常使用 HTTP 协议中的 HEAD、G…

129,【2】buuctf [BJDCTF2020]EzPHP

进入靶场 查看源代码 看到红框就知道对了 她下面那句话是编码后的&#xff0c;解码 1nD3x.php <?php // 高亮显示当前 PHP 文件的源代码&#xff0c;通常用于调试和展示代码结构 highlight_file(__FILE__); // 设置错误报告级别为 0&#xff0c;即不显示任何 PHP 错误信息…

介绍下SpringBoot如何处理大数据量业务

Spring Boot 处理大数据量业务时&#xff0c;通常会面临性能、内存、数据库负载等挑战。为了高效处理大数据量&#xff0c;Spring Boot 提供了多种解决方案和优化策略。以下是一些常见的处理方式&#xff1a; 1. 分页查询 问题&#xff1a;一次性查询大量数据会导致内存溢出和…

Go语言开发桌面应用基础框架(wails v3)-开箱即用框架

前言 本文是介绍如何集成好了Wails3开发框架以及提供视频教程&#xff0c;当你需要桌面开发时&#xff0c;直接下载我们基础框架代码&#xff0c;开箱即用不用配置开发需要依赖。 为什么使用v3版本&#xff0c;主要是v3新增的功能 ​支持多个窗口&#xff1a;在单个应用程序…

服务器,交换机和路由器的一些笔记

服务器、交换机和路由器是网络中常用的设备&#xff0c;它们的本质区别和联系如下&#xff1a; 本质区别 功能不同 服务器&#xff1a;就像一个大型的资料仓库和工作处理中心&#xff0c;主要用来存储和管理各种数据&#xff0c;比如网站的网页数据、公司的办公文档等&#x…

RDKit 给3D信息缺失的sdf生成三维结构

要生成包含三维结构的 SDF 文件&#xff0c;可以使用 RDKit 等化学信息学工具。以下是一个 Python 脚本示例&#xff0c;使用 RDKit 读取 SDF 文件、生成三维结构并保存。 ### 安装 RDKit 如果尚未安装 RDKit&#xff0c;可以通过以下命令安装&#xff1a; bash conda instal…