维度变换
1、创建张量
torch.zeros(*sizes, out=None, …)# 返回大小为sizes的零矩阵
torch.ones(*sizes, out=None, …) #f返回大小为sizes的单位矩阵
2、张量在某个维度上的的加减乘除
加:
torch.sum()
用于计算张量在某个维度上的元素之和,返回一个新的张量
减:torch.sub()
乘:torch.mul()
除:torch.div()
3、维度变换
【维度变形】view() 、 reshape() 、einops.rearrange()
★推荐使用:reshape() 、einops.rearrange()
einops.rearrange、repeat、reduce==>对维度进行操作_马鹏森的博客-CSDN博客
从功能上来看,它们的作用是相同的,都是用来重塑 Tensor 的 shape 的。
view 只适合对满足连续性条件 (contiguous) 的 Tensor进行操作,而reshape 同时还可以对不满足连续性条件的 Tensor 进行操作,具有更好的鲁棒性。view 能干的 reshape都能干,如果 view 不能干就可以用 reshape 来处理。
【维度增加】unsqueeze()、torch.expand() 、einops.repeat() 【维度减少】squeeze() 、einops.reduce()
★不推荐使用:torch.expand() 、
squeeze()
:对 tensor 进行维度的压缩,去掉维数为1
的维度。用法:torch.squeeze(a)
将 a 中所有为 1 的维度都删除,或者a.squeeze(1)
是去掉a
中指定的维数为1
的维度。unsqueeze()
:对数据维度进行扩充,给指定位置加上维数为1
的维度。用法:torch.unsqueeze(a, N)
,或者a.unsqueeze(N)
,在a
中指定位置N
加上一个维数为1
的维度。
【pytorch】torch.unsqueeze() 和 torch.squeeze()==>扩充维度和降维_torch扩展维度_马鹏森的博客-CSDN博客
expand() 函数只能将size=1的维度扩展到更大的尺寸,如果扩展其他维度会报错。
【维度之间交换】transpose()、 permute()、einops.rearrange()
★推荐使用:permute()、einops.rearrange()
torch.transpose()
只能交换两个维度,而 .permute()
可以自由交换任意位置。函数定义如下:
transpose(dim0, dim1) → Tensor # See torch.transpose()
permute(*dims) → Tensor # dim(int). Returns a view of the original tensor with its dimensions permuted.
在 CNN
模型中,我们经常遇到交换维度的问题,
举例:四个维度表示的 tensor:[batch, channel, h, w]
(nchw
),如果想把 channel
放到最后去,形成[batch, h, w, channel]
(nhwc
),如果使用 torch.transpose()
方法,至少要交换两次(先 1 3
交换再 1 2
交换),而使用 .permute()
方法只需一次操作,更加方便。例子程序如下:
import torch
input = torch.rand(1,3,28,32) # torch.Size([1, 3, 28, 32]
print(b.transpose(1, 3).shape) # torch.Size([1, 32, 28, 3])
print(b.transpose(1, 3).transpose(1, 2).shape) # torch.Size([1, 28, 32, 3])print(b.permute(0,2,3,1).shape) # torch.Size([1, 28, 32, 3]
三、合并分割
torch.flatten()
和torch.cat()
两个函数具有截然不同的功能。flatten()
重点在于将一个高维的复杂输入变成一维的线性向量,而cat()
重点在于将多个张量在某个维度上拼接为一个更大的张量。
【将任意维度张量转为一维张量】
torch.flatten()函数_马鹏森的博客-CSDN博客
torch.flatten(input, start_dim=0, end_dim=-1)
【维度拼接】torch.cat 和 torch.stack
【PyTorch】torch.cat==>张量拼接,在图像的应用上可以有效利用原始图像结构信息_特征图进行cat_马鹏森的博客-CSDN博客
np.stack()函数详解 ==>堆叠 【类似于torch.stack()】_马鹏森的博客-CSDN博客
可以用 torch.cat
方法和 torch.stack
方法将多个张量合并,torch.cat
和 torch.stack
有略微的区别,
torch.cat
是连接,不会增加一个新的维度,而 torch.stack
是堆叠,会增加一个新的维度。两者函数定义如下:
# Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
torch.cat(tensors, dim=0, *, out=None) → Tensor
# Concatenates a sequence of tensors along **a new** dimension. All tensors need to be of the same size.
torch.stack(tensors, dim=0, *, out=None) → Tensor
【维度分割】torch.split 和 torch.chunk
torch.split()
和 torch.chunk()
【在给定维度(轴)上将输入张量进行分块】可以看作是 torch.cat()
的逆运算。split()
作用是将张量拆分为多个块,每个块都是原始张量的视图。split()
函数定义如下:
"""
Splits the tensor into chunks. Each chunk is a view of the original tensor.
If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
"""
torch.split(tensor, split_size_or_sections, dim=0)
chunk()
作用是将 tensor
按 dim
(行或列)分割成 chunks
个 tensor
块,返回的是一个元组。chunk()
函数定义如下:
torch.chunk(input, chunks, dim=0) → List of Tensors
"""
Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.
Parameters:input (Tensor) – the tensor to splitchunks (int) – number of chunks to returndim (int) – dimension along which to split the tensor
"""