torch.cat 和 torch.stack看起来相似但是性质还是不同的
使用python中的list列表收录tensor时,然后将list列表转化成tensor时,会报错。这个时候就要使用torch.stack进行堆叠,转化成tensor。
- torch.cat()
torch.cat(tensors,dim=0,out=None)→ Tensor
torch.cat()对tensors沿指定维度拼接,但返回的Tensor的维数不会变
import torch
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.cat((a, b))
a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([4, 3]))
可以看到c和a、b一样都是二维的。
- torch.stack()
torch.stack(tensors,dim=0,out=None)→ Tensor
torch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维
import torch
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))
a.size(), b.size(), c.size()
(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([2, 2, 3]))
可以看到c是三维的,比a、b多了一维。