整理张量(tensor)中常用的操作

news/2024/10/17 1:27:46/

维度变换

1、创建张量

torch.zeros(*sizes, out=None, …)# 返回大小为sizes的零矩阵

torch.ones(*sizes, out=None, …) #f返回大小为sizes的单位矩阵

2、张量在某个维度上的的加减乘除

torch.sum() 用于计算张量在某个维度上的元素之和,返回一个新的张量

减:torch.sub()

乘:torch.mul()

除:torch.div()

3、维度变换

【维度变形】view() 、 reshape() 、einops.rearrange()

★推荐使用reshape() 、einops.rearrange()

einops.rearrange、repeat、reduce==>对维度进行操作_马鹏森的博客-CSDN博客

从功能上来看,它们的作用是相同的,都是用来重塑 Tensor 的 shape 的。

view 只适合对满足连续性条件 (contiguous) 的 Tensor进行操作,而reshape 同时还可以对不满足连续性条件的 Tensor 进行操作,具有更好的鲁棒性。view 能干的 reshape都能干,如果 view 不能干就可以用 reshape 来处理。

【维度增加】unsqueeze()、torch.expand() 、einops.repeat() 【维度减少】squeeze() 、einops.reduce()

★不推荐使用:torch.expand() 、

  • squeeze():对 tensor 进行维度的压缩,去掉维数为 1 的维度。用法:torch.squeeze(a) 将 a 中所有为 1 的维度都删除,或者 a.squeeze(1) 是去掉 a中指定的维数为 1 的维度。
  • unsqueeze():对数据维度进行扩充,给指定位置加上维数为 1 的维度。用法:torch.unsqueeze(a, N),或者 a.unsqueeze(N),在 a 中指定位置 N 加上一个维数为 1 的维度。

【pytorch】torch.unsqueeze() 和 torch.squeeze()==>扩充维度和降维_torch扩展维度_马鹏森的博客-CSDN博客

expand() 函数只能将size=1的维度扩展到更大的尺寸,如果扩展其他维度会报错。 

【维度之间交换】transpose()、 permute()、einops.rearrange()

★推荐使用:permute()、einops.rearrange()

torch.transpose() 只能交换两个维度,而 .permute() 可以自由交换任意位置。函数定义如下:

transpose(dim0, dim1) → Tensor  # See torch.transpose()
permute(*dims) → Tensor  # dim(int). Returns a view of the original tensor with its dimensions permuted.

在 CNN 模型中,我们经常遇到交换维度的问题,

举例:四个维度表示的 tensor:[batch, channel, h, w]nchw),如果想把 channel 放到最后去,形成[batch, h, w, channel]nhwc),如果使用 torch.transpose() 方法,至少要交换两次(先 1 3 交换再 1 2 交换),而使用 .permute() 方法只需一次操作,更加方便。例子程序如下:

import torch
input = torch.rand(1,3,28,32)                    # torch.Size([1, 3, 28, 32]
print(b.transpose(1, 3).shape)                   # torch.Size([1, 32, 28, 3])
print(b.transpose(1, 3).transpose(1, 2).shape)   # torch.Size([1, 28, 32, 3])print(b.permute(0,2,3,1).shape)                  # torch.Size([1, 28, 32, 3]

三、合并分割

torch.flatten()torch.cat()两个函数具有截然不同的功能。flatten()重点在于将一个高维的复杂输入变成一维的线性向量,而cat()重点在于将多个张量在某个维度上拼接为一个更大的张量。

将任意维度张量转为一维张量

torch.flatten()函数_马鹏森的博客-CSDN博客

torch.flatten(input, start_dim=0, end_dim=-1)

【维度拼接】torch.cat 和 torch.stack

【PyTorch】torch.cat==>张量拼接,在图像的应用上可以有效利用原始图像结构信息_特征图进行cat_马鹏森的博客-CSDN博客

np.stack()函数详解 ==>堆叠 【类似于torch.stack()】_马鹏森的博客-CSDN博客

可以用 torch.cat 方法和 torch.stack 方法将多个张量合并,torch.cat 和 torch.stack 有略微的区别,

torch.cat 是连接,不会增加一个新的维度,而 torch.stack 是堆叠,会增加一个新的维度。两者函数定义如下:

# Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
torch.cat(tensors, dim=0, *, out=None) → Tensor
# Concatenates a sequence of tensors along **a new** dimension. All tensors need to be of the same size.
torch.stack(tensors, dim=0, *, out=None) → Tensor

【维度分割】torch.split 和 torch.chunk

torch.split() 和 torch.chunk() 【在给定维度(轴)上将输入张量进行分块】可以看作是 torch.cat() 的逆运算split() 作用是将张量拆分为多个块,每个块都是原始张量的视图。split() 函数定义如下:

"""
Splits the tensor into chunks. Each chunk is a view of the original tensor.
If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
"""
torch.split(tensor, split_size_or_sections, dim=0)

chunk() 作用是将 tensor 按 dim(行或列)分割成 chunks 个 tensor 块,返回的是一个元组。chunk() 函数定义如下:

torch.chunk(input, chunks, dim=0) → List of Tensors
"""
Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.
Parameters:input (Tensor) – the tensor to splitchunks (int) – number of chunks to returndim (int) – dimension along which to split the tensor
"""


http://www.ppmy.cn/news/59440.html

相关文章

Linux 动态/静态配置ip网卡信息

配置网卡 第一步:查看网卡信息 查看网卡信息 在配置网卡之前,首先需要查看网卡信息。以下是在不同Linux发行版上查看网卡信息的方法: 方法一:使用ifconfig命令 输入ifconfig命令查看网卡信息。此命令适用于大多数Linux发行版&…

浙大数据结构第三周初识二叉树

03-树1 树的同构 (25分) 给定两棵树T1和T2。如果T1可以通过若干次左右孩子互换就变成T2,则我们称两棵树是“同构”的。例如图1给出的两棵树就是同构的,因为我们把其中一棵树的结点A、B、G的左右孩子互换后,就得到另外一棵树。而图2就不是同构…

音乐游戏《Tiles Hop》核心功能

文章目录 一、 介绍二、 进入游戏三、 初始化四、 游戏管理器五、 控制小球六、 音乐节奏编码 一、 介绍 音乐游戏《Tiles Hop》,随着音乐节奏进行跳跃 球在一定的速度下,特定的时候踩到砖块,同时正好和音乐的节奏要配合上; LRC歌词编辑器:…

对折纸张厚度超过珠峰

对折 0.1 毫米的纸张,循环对折,超过珠峰高度输出对折次数。 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖免费“圣经”教程《 python 完全自学教程》,不仅仅是基础那么简单…… 地址&a…

HTTPS建立连接原理、SSL工作原理

HTTPS与HTTP相比有什么区别? HTTPS保证安全的原理是什么? HTTPS是如何建立连接的? 巨人的肩膀 3.1 HTTP 常见面试题 | 小林coding HTTP与HTTPS的区别 HTTP是超文本传输协议,传输的内容是明文(HTTP1.1及之前版本)。HTTPS在TCP与HT…

第23章 MongoDB 复制(副本集)教程

第23章 MongoDB 复制(副本集)教程 MongoDB复制是将数据同步在多个server 的过程。 复制提供了数据的冗余备份,并在多个server 上存储数据副本,提高了数据的可用性, 并可以保证数据的安全性。 复制还允许尊敬的读者从…

浅拷贝和深拷贝

浅拷贝: 定义:浅拷贝(Shallow Copy)是一种简单的对象复制方式,将一个对象的数据成员直接复制给另一个对象(通常是通过默认的复制构造函数或赋值运算符实现),这些数据成员可以是基本…

每日一题142——最少操作使数组递增

给你一个整数数组 nums (下标从 0 开始)。每一次操作中,你可以选择数组中一个元素,并将它增加 1 。 比方说,如果 nums [1,2,3] ,你可以选择增加 nums[1] 得到 nums [1,3,3] 。 请你返回使 nums 严格递增…