Tensor轴变换 axis 或 dim(transpose、permute、view、reshape、einsum)

news/2025/3/23 22:51:41/

操作分类:

  1. 重排维度transpose、swapaxes、permute都是对维度进行重排序,但不改变维度的大小。

  2. 重组维度view、reshape可以重组原始维度,修改维度大小。

  3. 万能运算einsum 通过操作index(dim/axis)匹配对应的矩阵运算

    • dim 与 axis
    • transpose 重排维度
    • permute 重排维度
    • view 重组维度
    • reshape 重组维度
    • einsum 万能运算

dim 与 axis

Tensor的 dim维度axis轴 变换 是 Pytorch深度学习最重要的操作之一(在torch中叫dim多一些,在numpy中叫axis多一些),这些操作不改变内存中的物理存储,只会改变tensor的视图view,即以什么样的顺序或维度来看待这个tensor,越靠后的维度在内存上越相连每个维度都有具体的物理含义。可以通过tensor.shape来查看一个张量的维度。

如加载图像数据后,[32, 3, 64,64]可以理解为[batch_size, channel, hight, weight],如self-attention中[16, 8, 32, 128]可以理解为[batch_szie, heads, seq_len, head_dim]

tensor的dim索引从下标0开始,如shape为[10, 3, 64, 64]的tensor,其dim的取值范围是0,1,2,3

如下例子:

import torch
tensor = torch.randn(10, 3, 64, 64).to("cuda")
tensor.shape  # torch.Size([10, 3, 64, 64])
  • tensor[i]等价于tensor[i, :, :]tensor[i]的shape为[3, 64, 64];
  • tensor[i, j]等价于tensor[i, j, :]tensor[i, j]的shape为[64, 64].

transpose 重排维度

  • 使用方法torch.tanspose(tensor, dim1, dim2)交换 tensor 的 dim1 和 dim2 这两个维度
import torch
tensor = torch.randn(16, 8, 32, 128).to("cuda")
# torch.Size([16, 8, 32, 128])
trans = torch.transpose(tensor, 2, 3).contiguous()
# torch.Size([16, 8, 128, 32])

另外,swapaxes就是tanspose的别名!torch.swapaxes(tensor, dim1, dim2),效果等于上面的tanspose。

permute 重排维度

  • 使用方法transpose和swapaxes只能交换两个维度dim,而permute可以对所有轴进行重排torch.permute(dim1, dim2, dim3...)dim_i是原始维度的索引,将其放到新的位置,就是交换旧维度到新索引位置。
import torch
tensor = torch.randn(16, 8, 32, 128).to("cuda")
# torch.Size([16, 8, 32, 128])
tensor = tensor.permute(0, 2, 1, 3)  # 交换1,2维度
# torch.Size([16, 32, 8, 128])

view 重组维度

  • 使用方法tensor.contiguous().view(dim0, dim1, dim2...) ,将tensor的shape变换为(dim0, dim1, dim2...)dim的个数可以少于或多于原来tensor!,因为所有维度的累积 ∏ i = 0 N d i m i \prod_{i=0}^N{dim_i} i=0Ndimi是不变的,因此当有一个dim=-1时,将自动计算。
import torch
tensor = torch.randn(16, 8, 32, 16).to("cuda")
# torch.Size([16, 8, 32, 16])  (batch_size, heads, seq_len, head_dim)
tensor = tensor.contiguous().view(16, 32, -1)  # 合头heads
# torch.Size([16, 32, 128])  (batch_size, seq_len, dim)
  • contiguous:因为transpose和permute这些操作不改变内存中的物理存储,而torch要求 越靠后的维度在内存上越相连,所以按照新维度索引,tensor在内存中不再是连续存储的,但view操作要求tensor的内存连续存储,需要用tensor.contiguous() 将原始的tensor调整为一个内存连续的tensor。在pytorch 0.4中,增加了torch.reshape()操作,大致相当于 tensor.contiguous().view(),这样就省去了对tensor做view()变换前,调用contiguous()的麻烦;因此建议所有情况都无脑使用 reshape

reshape 重组维度

  • 使用方法tonsor.reshape()tensor.contiguous().view()tensor.reshape(dim0, dim1, dim2...) ,将tensor的shape变换为(dim0, dim1, dim2...)dim的个数可以少于或多于原来tensor!,因为所有维度的累积 ∏ i = 0 N d i m i \prod_{i=0}^N{dim_i} i=0Ndimi是不变的,因此当有一个dim=-1时,将自动计算。
import torch
tensor = torch.randn(16, 8, 32, 16).to("cuda")
# torch.Size([16, 8, 32, 16])  (batch_size, heads, seq_len, head_dim)
tensor = tensor.reshape(16, 32, -1)  # 合头heads
# torch.Size([16, 32, 128])  (batch_size, seq_len, dim)

einsum 万能运算

  • 使用方法:爱因斯坦表达式通过操作index(dim/axis)匹配对应的矩阵运算。和前面几个操作不同的是,torch.einsum不仅可以进行单个矩阵维度的重排、重组,还可以完成多个矩阵的矩阵加法矩阵乘法元素乘法等运算

->左侧表示输入的矩阵shape,->右侧表示输出的矩阵shape

  • permute 重排:单个输入矩阵,->左右维度数量不变,只改变顺序,如交换i和j维度,ij->ji
import torch
tensor = torch.randn(16, 8, 32, 16).to("cuda")
# torch.Size([16, 8, 32, 16])
tensor = torch.einsum("bhsd->bhds", tensor)
# torch.Size([16, 8, 16, 32])
  • sum求和:单个输入矩阵,->右侧缺少哪些维度,就按照哪些维度求和,如按照j维度求和,ij->i
import torch
tensor = torch.randn(16, 8, 32, 16).to("cuda")
# torch.Size([16, 8, 32, 16])
tensor = torch.einsum("bhsd->bh", tensor)
# torch.Size([16, 8])
  • matrix multi 矩阵乘法->左边多个输入矩阵逗号分隔,->左边是单个矩阵,沿左边两者重复出现右边消失的维度进行乘法,如沿k维度进行矩阵乘法,ij,jk->ik
tensor1 = torch.randn(2, 3).to("cuda")
tensor2 = torch.randn(3, 5).to("cuda")
tensor = torch.einsum("ij, jk -> ik", tensor1, tensor2)
# (2,3) @ (3,5) = (2,5)

组合操作:先沿着j维度进行矩阵乘法,再沿着k维度进行求和:

tensor1 = torch.randn(2, 3).to("cuda")
tensor2 = torch.randn(3, 5).to("cuda")
tensor = torch.einsum("ij, jk -> i", tensor1, tensor2)
# (2,3) @ (3,5) = (2,5)

更加复杂的组合操作:模拟attention score,先自动进行转置,然后最后两个维度进行矩阵乘法,其中虽然都有seq_len,但因为output输出矩阵中不能出现两个相同的字母,所以不能都用s命名,因此使用i和j

import torch
# key 和 value 都是[batch_size, heads, seq_len, head_dim]
query = torch.randn(16, 8, 32, 16).to("cuda")
key = torch.randn(16, 8, 32, 16).to("cuda")
attention_score = torch.einsum("bhid, bhjd -> bhij", query, key)  # bhid, bhjd -> bhid, bhdj -> bhij
# torch.Size([16, 8, 32, 32])# 等价操作
attention_score = query @ key.transpose(-2, -1)
attention_score = torch.matmul(query, key.transpose(-2, -1))
  • element-wise multi 元素乘法->左边多个相同shape的矩阵,->右边单个和做左边相同shape的矩阵。矩阵对应元素相乘,也叫hadamard product
import torch
tensor1 = torch.randn(16, 8, 32, 16).to("cuda")
tensor2 = torch.randn(16, 8, 32, 16).to("cuda")tensor = torch.einsum("bhsd,bhsd->bhsd", tensor1, tensor2)
# torch.Size([16, 8, 32, 16])# 等价操作
tensor = tensor1 * tensor2
  • dot product 矩阵点积->左边多个相同shape的矩阵,->是空的(求和sum)。即,先逐元素相乘,然后全部求和
import torch
tensor1 = torch.randn(16, 8, 32, 16).to("cuda")
tensor2 = torch.randn(16, 8, 32, 16).to("cuda")tensor = torch.einsum("bhsd,bhsd-> ", tensor1, tensor2)
# tensor是一个值# 等价操作
tensor = sum(tensor1 * tensor2)

http://www.ppmy.cn/news/1291551.html

相关文章

Oracle/Myql批量操作

前言&#xff1a;在oracle中使用insert into values (),(),()多种方式都不能成功,记录正确的批量方法 注意&#xff1a;oracle有自己实现批量的方法&#xff0c;mysql适用的&#xff0c;oracle不一定适用 <insert id"insertTaskImportOpen" parameterType"l…

使用 JWT(JSON Web 令牌)实现登录身份验证和令牌续订

文档链接 文档链接, PDF中包含一部分宣传大字制作不易还望多多支持互相交流 使用 JWT&#xff08;JSON Web 令牌&#xff09;实现登录身份验证和令牌续订。它将 JWT 与基于会话的身份验证进行了比较&#xff0c;并强调了每种方法的差异、优点和缺点。本文档介绍了基于会话的方…

Dockerfile学习文档

Dockerfile详解 Dockerfile是一个组合映像命令的文本&#xff1b;可以使用在命令行中调用任何命令&#xff1b;Docker通过dockerfile中的指令自动生成镜像。 通过docker build -t repository:tag ./ 即可构建&#xff0c;要求&#xff1a;./下存在Dockerfile文件 之前我们聊的…

elasticsearch系列七:聚合查询

概述 今天咱们来看下es中的聚合查询&#xff0c;在es中聚合查询分为三大类bucket、metrics、pipeline&#xff0c;每一大类下又有十几种小类&#xff0c;咱们各举例集中&#xff0c;有兴许的同学可以参考官网&#xff1a;https://www.elastic.co/guide/en/elasticsearch/refere…

(NeRF学习)NeRF复现 win11

目录 一、获取源码二、环境三、准备数据集方法一&#xff1a;官方命令方法二&#xff1a;官网下载数据集 四、开始训练1.更改迭代次数2.开始训练方法一&#xff1a;方法二&#xff1a; 3.使用预训练模型 五、NeRF源码学习 一、获取源码 git clone https://github.com/bmild/ne…

MySQL中的事务到底是怎么一回事儿

简单来说&#xff0c;事务就是要保证一组数据库操作&#xff0c;要么全部成功&#xff0c;要么全部失败。在MySQL中&#xff0c;事务支持是在引擎层实现的&#xff0c;但并不是所有的引擎都支持事务&#xff0c;如MyISAM引擎就不支持事务&#xff0c;这也是MyISAM被InnoDB取代的…

931. 下降路径最小和-Python-DP-简单题

Problem: 931. 下降路径最小和 文章目录 思路解题方法复杂度Code 思路 看了一些题解&#xff0c;感觉写的很复杂&#xff0c;其实我的思考很简单&#xff0c;直接在原数组进行修改 解题方法 第一行不变&#xff0c;从第二行开始&#xff0c;能到达当前位置的路径最多只有三条&a…

第十四章 14.2案例:使用KVM命令集管理虚拟机

查看命令帮助 [rootLinux01 ~]# virsh -h—————————————————————————————————————————— 查看KVM的配置文件存放目录〈test01 , xml是虚拟机系统实例的配置文件) [rootLinux01 ~]# ls /etc/libvirt/qemu —————————————…