在 PyTorch 中,squeeze()
函数用于压缩张量中的尺寸为 1 的维度。当张量中存在尺寸为 1 的维度时,squeeze()
函数可以将这些尺寸为 1 的维度去除,从而减少张量的维度。
在 squeeze()
函数中,可以传入一个参数来指定要压缩的维度。如果指定了这个参数,则只会对指定的维度进行压缩;如果没有指定参数,则会压缩所有尺寸为 1 的维度。
而 (2)
则是指定要压缩的维度,这里的 2
表示要压缩的维度的索引,即将维度索引为 2 的尺寸为 1 的维度去除。
举例来说,如果有一个形状为 (3, 1, 2)
的张量 x
,其中第二个维度的尺寸为 1,可以使用 squeeze(1)
来去除这个尺寸为 1 的维度。如果没有指定参数,则会去除所有尺寸为 1 的维度。
python">import torch# 创建一个形状为 (3, 1, 2) 的张量
x = torch.randn(3, 1, 2)# 压缩维度为 1 的尺寸为 1 的维度
y = x.squeeze(1)print(y.shape) # 输出:torch.Size([3, 2])
在上面的例子中,squeeze(1)
压缩了张量 x
中第二个维度的尺寸为 1 的维度,最终得到的张量 y
的形状为 (3, 2)
。