在张量操作中,unsqueeze,squeeze,reshape,view
1. unsqueeze
- 功能: 在指定维度上增加一个新的维度,通常用于将一维张量扩展为二维,以便符合批处理的要求。
- 举例:
python">import torchx = torch.tensor([1, 2, 3]) # Shape: (3,) y = x.unsqueeze(0) # Shape: (1, 3) print(y)
- 数值示例:
x
是[1, 2, 3]
,形状为(3,)
。使用unsqueeze(0)
后,形状变为(1, 3)
,结果为[[1, 2, 3]]
。
- 数值示例:
2. expand
- 功能: 通过复制数据来扩展张量的尺寸,而不增加新的维度。它不会在内存中实际复制数据,而是通过广播机制实现。
- 举例:
python">import torchx = torch.tensor([1, 2, 3]) # Shape: (3,) y = x.unsqueeze(0) # Shape: (1, 3) z = y.expand(3, 3) # Shape: (3, 3) print(z)
- 数值示例:
y
是[[1, 2, 3]]
,形状为(1, 3)
。使用expand(3, 3)
后,形状变为(3, 3)
,结果为[[1, 2, 3], [1, 2, 3], [1, 2, 3]]
。
- 数值示例:
3. reshape
- 功能: 改变张量的形状,只要元素的总数保持不变。
- 举例:
python">import torchx = torch.tensor([1, 2, 3, 4]) # Shape: (4,) y = x.reshape(2, 2) # Shape: (2, 2) print(y)
- 数值示例:
x
是[1, 2, 3, 4]
,形状为(4,)
。使用reshape(2, 2)
后,形状变为(2, 2)
,结果为[[1, 2], [3, 4]]
。
- 数值示例:
4. view
- 功能: 类似于
reshape
,但要求张量在内存中是连续的。它改变张量的形状,但不改变数据的存储方式。 - 举例:
python">import torchx = torch.tensor([1, 2, 3, 4]) # Shape: (4,) y = x.view(2, 2) # Shape: (2, 2) print(y)
- 数值示例:
x
是[1, 2, 3, 4]
,形状为(4,)
。使用view(2, 2)
后,形状变为(2, 2)
,结果为[[1, 2], [3, 4]]
。
- 数值示例:
5. squeeze
- 功能: 去除张量中尺寸为 1 的维度。
- 举例:
python">import torchx = torch.tensor([[[1, 2, 3]]]) # Shape: (1, 1, 3) y = x.squeeze() # Shape: (3,) print(y)
- 数值示例:
x
是[[[1, 2, 3]]]
,形状为(1, 1, 3)
。使用squeeze()
后,形状变为(3,)
,结果为[1, 2, 3]
。
- 数值示例:
总结:
unsqueeze
: 增加一个新维度。示例:从[1, 2, 3]
变为[[1, 2, 3]]
。expand
: 扩展现有维度的大小。示例:从[[1, 2, 3]]
变为[[1, 2, 3], [1, 2, 3], [1, 2, 3]]
。reshape
: 重新排列数据的形状。示例:从[1, 2, 3, 4]
变为[[1, 2], [3, 4]]
。view
: 与reshape
类似,但有内存连续性要求。示例:从[1, 2, 3, 4]
变为[[1, 2], [3, 4]]
。squeeze
: 删除尺寸为 1 的维度。示例:从[[[1, 2, 3]]]
变为[1, 2, 3]
。