深度学习02-pytorch-01-张量形状的改变

embedded/2024/9/23 2:21:01/

在 PyTorch 中,张量的形状(或称为形状变换)可以通过多种方式进行改变,这有助于数据的重新排列、打平、扩展或压缩。常用的操作包括 view(), reshape(), transpose(), unsqueeze(), squeeze(), 和 permute() 等等。下面将详细介绍这些常见的形状改变方法。

1. view()

view() 是 PyTorch 中常用的形状变换函数,它可以改变张量的形状,但要求张量在内存中是连续存储的。

import torch
tensor = torch.randn(2, 3, 4)  # 原始张量形状为 (2, 3, 4)
print(tensor.size())  # 输出: torch.Size([2, 3, 4])
​
# 使用 view 变换形状
reshaped_tensor = tensor.view(6, 4)  # 变为形状 (6, 4)
print(reshaped_tensor.size())  # 输出: torch.Size([6, 4])

注意:view() 只能用于在内存中连续的张量,如果内存不连续,会报错。可以使用 tensor.contiguous() 使其连续。

2. reshape()

reshape() 功能类似于 view(),但它会自动处理张量是否连续的问题,即使张量不连续,也能重新调整形状。

reshaped_tensor = tensor.reshape(6, 4)  # 变为形状 (6, 4)
print(reshaped_tensor.size())  # 输出: torch.Size([6, 4])

3. transpose()

transpose() 交换两个维度的位置。它不改变张量的存储顺序,只是交换了维度的显示顺序。

tensor = torch.randn(2, 3)  # 形状为 (2, 3)
transposed_tensor = tensor.transpose(0, 1)  # 交换第 0 维和第 1 维
print(transposed_tensor.size())  # 输出: torch.Size([3, 2])

4. permute()

permute() 可以按照指定顺序重排张量的所有维度。它比 transpose() 更灵活。

tensor = torch.randn(2, 3, 4)  # 原始形状 (2, 3, 4)
permuted_tensor = tensor.permute(2, 0, 1)  # 将维度顺序调整为 (4, 2, 3)
print(permuted_tensor.size())  # 输出: torch.Size([4, 2, 3])

5. squeeze()

squeeze() 会移除张量中大小为 1 的维度。

tensor = torch.randn(1, 3, 1, 4)  # 形状为 (1, 3, 1, 4)
squeezed_tensor = tensor.squeeze()  # 移除所有大小为 1 的维度
print(squeezed_tensor.size())  # 输出: torch.Size([3, 4])

你也可以指定要移除的维度:

squeezed_tensor = tensor.squeeze(2)  # 只移除第 2 维(大小为 1)

6. unsqueeze()

unsqueeze() 用来在指定的位置增加一个大小为 1 的维度。它与 squeeze() 相反。

tensor = torch.randn(3, 4)  # 形状为 (3, 4)
unsqueezed_tensor = tensor.unsqueeze(0)  # 在第 0 维添加一个大小为 1 的维度
print(unsqueezed_tensor.size())  # 输出: torch.Size([1, 3, 4])

7. flatten()

flatten() 用来将多维张量展平成一个一维张量。

tensor = torch.randn(2, 3, 4)  # 形状为 (2, 3, 4)
flattened_tensor = tensor.flatten()  # 展平成一维张量
print(flattened_tensor.size())  # 输出: torch.Size([24])

你也可以指定展平的范围:

flattened_tensor = tensor.flatten(start_dim=1)  # 从第 1 维开始展平
print(flattened_tensor.size())  # 输出: torch.Size([2, 12])

总结

  • view(): 改变张量形状,要求连续存储。

  • reshape(): 改变张量形状,处理连续与否问题。

  • transpose(): 交换两个维度。

  • permute(): 自由调整所有维度的顺序。

  • squeeze(): 移除大小为 1 的维度。

  • unsqueeze(): 添加大小为 1 的维度。

  • flatten(): 将张量展平为一维。

这些形状变换操作是 PyTorch 中常用的工具,有助于你更灵活地操作张量并适应深度模型的需求。


http://www.ppmy.cn/embedded/115351.html

相关文章

【java面经】Redis速记

目录 基本概念 string hash list set zset 常见问题及解决 缓存穿透 缓存击穿 缓存雪崩 Redis内存管理策略 noeviction allkeys-lru allkeys-random volatile-random volatile-ttl Redis持久化机制 RDB快照 AOF追加文件 Redis多线程特性 Redis应用场景 缓…

数据处理与统计分析篇-day08-apply()自定义函数与分组操作

一. 自定义函数 概述 当Pandas自带的API不能满足需求, 例如: 我们需要遍历的对Series中的每一条数据/DataFrame中的一列或一行数据做相同的自定义处理, 就可以使用Apply自定义函数 apply函数可以接收一个自定义函数, 可以将Series对象的逐个值或DataFrame的行/列数据传递给自…

用Python提取PowerPoint演示文稿中的音频和视频

将多种格式的媒体内容进行重新利用(如PowerPoint演示中的音频和视频)是非常有价值的。无论是创建独立的音频文件、提取视频以便在线分发,还是为了未来的使用需求进行资料归档,从演示文稿中提取这些媒体文件可以为多媒体内容的多次…

API 接入前的安全防线:注意事项全梳理

在当今数字化的商业环境中,API(Application Programming Interface)的广泛应用为企业带来了诸多便利,但同时也伴随着潜在的安全风险。在接入 API 之前,构建坚实的安全防线至关重要。以下是对 API 接入前安全注意事项的…

达梦disql支持上翻历史命令-安装rlwrap

time:2024/09/18 Author:skatexg 一、背景 DM安装完成后使用disql命令行,无法使用上下键引用历史命令,会出现“[[A[[A”的现象。这样的操作包括使用退格Backspace键,上下键,左右键等。解决这个问题,可以使用rlwrap工…

Python练习宝典:Day 1 - 选择题 - 基础知识

目录 一、踏上Python之旅二、Python语言基础三、流程控制语句四、序列的应用 一、踏上Python之旅 1.想要输出 I Love Python,应该使用()函数。 A.printf() B.print() C.println() D.Print()2.Python安装成功的标志是在控制台(终端)输入python/python3后,命令提示符变为: A.&…

自监督的主要学习方法

自监督学习是一种机器学习方法,其中模型从未标注的数据中学习生成标签,通常通过构造预训练任务或预测任务来从数据的内部结构中提取信息。它的核心目标是利用无监督的数据进行学习,从而在下游任务中更好地利用监督信号。自监督学习的主要方法…

QNX Hypervisor(十)Linux Guest IPC 二

上文还遗留了一个问题,就是在测试ipc的时候挂死了。相关原理我写在了另外一篇文章。 内存管理 所以导致挂死的问题就是因为没有进行地址映射,mmu无法转换。从kernel代码看,只有ram区域才会进行映射。我们的qvmconf文件也确实没有配置0xb8000000,只配置了pass。 pass loc …