Pytorch cat()与stack()函数详解

news/2024/9/22 19:59:09/

torch.cat()

cat为concatenate的缩写,意思为拼接,torch.cat()函数一般是用于张量拼接使用的

cat(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor:

可以看到cat()函数的参数,常用的参数为,第一个参数:可以选择元组或者列表,内部包含需要拼接的张量,需要按照顺序排列,第二个参数为dim,用于指定需要拼接的维度

python">import torch
import numpy as npdata1 = torch.randint(0, 10, [2, 3, 4])
data2 = torch.randint(0, 10, [2, 3, 4])print(data1)
print(data2)
print("-" * 20)print(torch.cat([data1, data2], dim=0))
print(torch.cat([data1, data2], dim=1))
print(torch.cat([data1, data2], dim=2))
# tensor([[[9, 4, 0, 0],
#          [3, 3, 7, 6],
#          [6, 1, 0, 8]],
# 
#         [[9, 1, 1, 2],
#          [1, 0, 6, 4],
#          [7, 9, 3, 9]]])
# tensor([[[3, 2, 6, 3],
#          [8, 3, 1, 1],
#          [0, 9, 2, 5]],
# 
#         [[2, 6, 7, 5],
#          [9, 1, 0, 1],
#          [0, 6, 4, 4]]])
# --------------------
# tensor([[[9, 4, 0, 0],
#          [3, 3, 7, 6],
#          [6, 1, 0, 8]],
# 
#         [[9, 1, 1, 2],
#          [1, 0, 6, 4],
#          [7, 9, 3, 9]],
# 
#         [[3, 2, 6, 3],
#          [8, 3, 1, 1],
#          [0, 9, 2, 5]],
# 
#         [[2, 6, 7, 5],
#          [9, 1, 0, 1],
#          [0, 6, 4, 4]]])
# tensor([[[9, 4, 0, 0],
#          [3, 3, 7, 6],
#          [6, 1, 0, 8],
#          [3, 2, 6, 3],
#          [8, 3, 1, 1],
#          [0, 9, 2, 5]],
# 
#         [[9, 1, 1, 2],
#          [1, 0, 6, 4],
#          [7, 9, 3, 9],
#          [2, 6, 7, 5],
#          [9, 1, 0, 1],
#          [0, 6, 4, 4]]])
# tensor([[[9, 4, 0, 0, 3, 2, 6, 3],
#          [3, 3, 7, 6, 8, 3, 1, 1],
#          [6, 1, 0, 8, 0, 9, 2, 5]],
# 
#         [[9, 1, 1, 2, 2, 6, 7, 5],
#          [1, 0, 6, 4, 9, 1, 0, 1],
#          [7, 9, 3, 9, 0, 6, 4, 4]]])

上述代码演示了拼接维度为0,1,2的时候的结果,可以看出cat()并不会影响张量的维度,如上述的三维张量拼接,若dim为0则按块(后两位张量组成的二维张量)进行拼接,若dim为1则按行拼接,若dim为2则按列拼接

torch.stack()

stack为堆叠、栈的意思

stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: _int = 0, *, out: Optional[Tensor] = None) -> Tensor: 

可以看到stack()和cat()的用法几乎一致,都是用于堆叠张量组成的列表或元组,以及堆叠的维度dim

python">import torch
import numpy as npdata1 = torch.randint(0, 10, [2, 3, 4])
data2 = torch.randint(0, 10, [2, 3, 4])print(data1)
print(data2)
print("-" * 20)data3 = torch.stack([data1, data2], dim=0)
data4 = torch.stack([data1, data2], dim=1)
data5 = torch.stack([data1, data2], dim=2)
data6 = torch.stack([data1, data2], dim=3)
print(data3.shape)
print(data3)
print(data4.shape)
print(data4)
print(data5.shape)
print(data5)
print(data6.shape)
print(data6)# tensor([[[1, 6, 6, 1],
#          [3, 1, 8, 2],
#          [0, 4, 7, 3]],
# 
#         [[4, 7, 5, 6],
#          [5, 4, 0, 2],
#          [8, 0, 3, 0]]])
# tensor([[[5, 2, 7, 2],
#          [7, 4, 2, 0],
#          [8, 5, 5, 9]],
# 
#         [[7, 1, 5, 6],
#          [3, 5, 4, 7],
#          [1, 0, 8, 8]]])
# --------------------
# torch.Size([2, 2, 3, 4])
# tensor([[[[1, 6, 6, 1],
#           [3, 1, 8, 2],
#           [0, 4, 7, 3]],
# 
#          [[4, 7, 5, 6],
#           [5, 4, 0, 2],
#           [8, 0, 3, 0]]],
# 
# 
#         [[[5, 2, 7, 2],
#           [7, 4, 2, 0],
#           [8, 5, 5, 9]],
# 
#          [[7, 1, 5, 6],
#           [3, 5, 4, 7],
#           [1, 0, 8, 8]]]])
# torch.Size([2, 2, 3, 4])
# tensor([[[[1, 6, 6, 1],
#           [3, 1, 8, 2],
#           [0, 4, 7, 3]],
# 
#          [[5, 2, 7, 2],
#           [7, 4, 2, 0],
#           [8, 5, 5, 9]]],
# 
# 
#         [[[4, 7, 5, 6],
#           [5, 4, 0, 2],
#           [8, 0, 3, 0]],
# 
#          [[7, 1, 5, 6],
#           [3, 5, 4, 7],
#           [1, 0, 8, 8]]]])
# torch.Size([2, 3, 2, 4])
# tensor([[[[1, 6, 6, 1],
#           [5, 2, 7, 2]],
# 
#          [[3, 1, 8, 2],
#           [7, 4, 2, 0]],
# 
#          [[0, 4, 7, 3],
#           [8, 5, 5, 9]]],
# 
# 
#         [[[4, 7, 5, 6],
#           [7, 1, 5, 6]],
# 
#          [[5, 4, 0, 2],
#           [3, 5, 4, 7]],
# 
#          [[8, 0, 3, 0],
#           [1, 0, 8, 8]]]])
# torch.Size([2, 3, 4, 2])
# tensor([[[[1, 5],
#           [6, 2],
#           [6, 7],
#           [1, 2]],
# 
#          [[3, 7],
#           [1, 4],
#           [8, 2],
#           [2, 0]],
# 
#          [[0, 8],
#           [4, 5],
#           [7, 5],
#           [3, 9]]],
# 
# 
#         [[[4, 7],
#           [7, 1],
#           [5, 5],
#           [6, 6]],
# 
#          [[5, 3],
#           [4, 5],
#           [0, 4],
#           [2, 7]],
# 
#          [[8, 1],
#           [0, 0],
#           [3, 8],
#           [0, 8]]]])

可以看到dim设置为几,就会按第几个维度进行堆叠拼接,dim为0则是整体堆叠后升维,dim为1则是按第二个维度也就是后两维张量为一个整体进行两个张量对应堆叠拼接,dim为2为按后两维中的行进行堆叠拼接,dim为3也就是按两个张量的单个值进行对应堆叠拼接

stack()随着维度增加,理解会较为复杂,具体可见代码和结果演示

注意,cat()和stack()中的dim参数也可以使用负索引,即从-1开始进行维度索引


http://www.ppmy.cn/news/1512347.html

相关文章

掉头发特别厉害的日子要来了!用对这3个方法,让头发重新乌黑浓密起来!

最近天气转凉,马上就要迎来处暑,正式进入秋季! 很多人都有这样一个感受:进入秋天后,就特别容易掉头发,不管洗头、梳头还是睡觉,一抓头发总会掉几根甚至更多。 枕头上、沙发上、地板上.....头发遍…

微软运行库全集合:一站式解决兼容性问题

开发者在部署应用程序时经常遇到因缺少运行库而引发的兼容性问题。为了解决这一问题,电脑天空推荐微软常用运行库合集,一个集成了微软多个关键运行库组件的软件包。 📚 包含组件概览: Visual Basic Virtual Machine:…

【数据结构】线段树 需要pushdown

建树lrpushup单点修改lrpushup区间查询包一旦题目中pushdown就必须pushdown&#xff0c;否则也无需区间修改包pushuppushdown AcWing 243. 一个简单的整数问题2 - AcWing #include<iostream> using namespace std; #define ll long long struct Tree{int l,r;ll sum,ad…

收银系统源码-连锁店解决方案

千呼新零售2.0系统由零售行业连锁店一体化收银系统和多商户入驻平台商城两个板块组成&#xff0c;打造门店平台的本地生活即时零售模式。 其中连锁店收银系统包括线下收银私域商城连锁店管理ERP管理商品管理供应商管理会员营销等功能为一体&#xff0c;线上线下数据全部打通。…

[MRCTF2020]套娃1

打开题目&#xff0c;查看源代码&#xff0c;有提示 有两层过滤 1.过滤"_"与"%5f" 。 这里要求的参数必须是"b_u_p_t"但是不能检测出"_"。这里看着很作弄人。其实这里要用到php里非法参数名的问题。可以参考一下博客 ?b.u.p.t2333…

字典树(Trie)

Trie字符串统计 描述 维护一个字符串集合&#xff0c;支持两种操作&#xff1a; “I x”向集合中插入一个字符串x&#xff1b;“Q x”询问一个字符串在集合中出现了多少次。 共有N个操作&#xff0c;输入的字符串总长度不超过105105&#xff0c;字符串仅包含小写英文字母。…

Mysql磁盘满问题

Temporary file write failure show processlist;kill id mysql创建索引导致死锁&#xff0c;数据库崩溃&#xff0c;mysql的表级锁之【元数据锁&#xff08;meta data lock&#xff0c;MDL)】全解_metadata_locks_秃了也弱了。的博客-CSDN博客 多个waiting for handler commit…

Linux ubuntu 24.04 运行《文明5》游戏,解决游戏中文设置的问题!

Linux ubuntu 24.04 运行《文明5》游戏&#xff0c;解决游戏中文设置的问题&#xff01; 《文明5》是一款回合制经营策略游戏&#xff0c;拼的就是科技发展速度&#xff0c;点的是科技树&#xff0c;抢的就是科技制高点&#xff0c;但是真的是时间漫长&#xff0c;可能需要好几…