【PyTorch】6.张量形状操作:在深度学习的 “魔方” 里,玩转张量形状

devtools/2025/2/3 8:14:39/

       

目录

1. reshape 函数的用法

2. transpose 和 permute 函数的使用

4. squeeze 和 unsqueeze 函数的用法

5. 小节


个人主页:Icomi

专栏地址:PyTorch入门

在深度学习蓬勃发展的当下,PyTorch 是不可或缺的工具。它作为强大的深度学习框架,为构建和训练神经网络提供了高效且灵活的平台。神经网络作为人工智能的核心技术,能够处理复杂的数据模式。通过 PyTorch,我们可以轻松搭建各类神经网络模型,实现从基础到高级的人工智能应用。接下来,就让我们一同走进 PyTorch 的世界,探索神经网络人工智能的奥秘。本系列为PyTorch入门文章,若各位大佬想持续跟进,欢迎与我交流互关。

         咱们已经学习了张量的花式索引操作,它就像一把精巧的工具,让我们能够在数据的 “宝藏库” 里精准地提取和修改信息。我们接下来要学习—— 掌握对张量形状的操作。

        想象一下,我们即将搭建的网络模型就像一座宏伟而复杂的建筑,而数据则是构建这座建筑的基石。这些数据在我们的深度学习世界里,都是以张量的形式存在。在这座 “建筑” 中,不同的楼层(网络层)有着不同的功能和设计,它们之间的数据传递和运算就如同建筑中不同楼层之间的物资运输和协作。

        每一层网络对数据的处理方式都不尽相同,这就导致数据在网络层与层之间流动时,会以不同的形状(shape)进行表现和运算。比如说,有的层可能接收的是二维的张量数据,经过处理后输出一个三维的张量,就像把方形的积木经过加工变成了一个立体的模型

        如果我们不掌握对张量形状的操作,就好比一个建筑工人不熟悉不同建筑材料的尺寸和拼接方式,那么在搭建这座 “网络建筑” 时,各层之间的数据连接就会出现问题,就像积木无法正确拼接,最终导致整个建筑摇摇欲坠。

        为了能够更好地处理网络各层之间的数据连接,顺利搭建出稳固而强大的网络模型,掌握对张量形状的操作就显得尤为重要。接下来,我们就一同深入学习如何巧妙地调整和管理张量的形状,让我们在深度学习的建筑之路上稳步前行。

1. reshape 函数的用法

reshape 函数可以在保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状,在后面的神经网络学习时,会经常使用该函数来调节数据的形状,以适配不同网络层之间的数据传递。

python">import torch
import numpy as npdef tensor_shape_operations():# 创建一个二维张量tensor = torch.tensor([[10, 20, 30], [40, 50, 60]])# 1. 使用 shape 属性或者 size 方法都可以获得张量的形状print(f"使用 shape 属性获取的形状: {tensor.shape},第 0 维大小: {tensor.shape[0]},第 1 维大小: {tensor.shape[1]}")print(f"使用 size 方法获取的形状: {tensor.size()},第 0 维大小: {tensor.size(0)},第 1 维大小: {tensor.size(1)}")# 2. 使用 reshape 函数修改张量形状reshaped_tensor = tensor.reshape(1, 6)print(f"修改形状后的张量形状: {reshaped_tensor.shape}")if __name__ == '__main__':tensor_shape_operations()

需要注意的是,转换前后的两个形状元素个数要相同

python">import torchdef test():torch.manual_seed(0)data = torch.randint(0, 10, [4, 5])# 查看张量的形状print(data.shape, data.shape[0], data.shape[1])print(data.size(), data.size(0), data.size(1))# 修改张量的形状new_data = data.reshape(2, 10)print(new_data)# 注意: 转换之后的形状元素个数得等于原来张量的元素个数# new_data = data.reshape(1, 10)# print(new_data)# 使用-1代替省略的形状new_data = data.reshape(5, -1)print(new_data)new_data = data.reshape(-1, 2)print(new_data)if __name__ == '__main__':test()

2. transpose 和 permute 函数的使用

transpose 函数可以实现交换张量形状的指定维度, 例如: 一个张量的形状为 (2, 3, 4) 可以通过 transpose 函数把 3 和 4 进行交换, 将张量的形状变为 (2, 4, 3)

permute 函数可以一次交换更多的维度。

python">import torch
import numpy as npdef test():data = torch.tensor(np.random.randint(0, 10, [3, 4, 5]))print('data shape:', data.size())# 1. 交换1和2维度new_data = torch.transpose(data, 1, 2)print('data shape:', new_data.size())# 2. 将 data 的形状修改为 (4, 5, 3)new_data = torch.transpose(data, 0, 1)new_data = torch.transpose(new_data, 1, 2)print('new_data shape:', new_data.size())# 3. 使用 permute 函数将形状修改为 (4, 5, 3)new_data = torch.permute(data, [1, 2, 0])print('new_data shape:', new_data.size())if __name__ == '__main__':test()

4. squeeze 和 unsqueeze 函数的用法

squeeze 函数用删除 shape 为 1 的维度,unsqueeze 在每个维度添加 1, 以增加数据的形状

python">import torch
import numpy as npdef test():data = torch.tensor(np.random.randint(0, 10, [1, 3, 1, 5]))print('data shape:', data.size())# 1. 去掉值为1的维度new_data = data.squeeze()print('new_data shape:', new_data.size())  # torch.Size([3, 5])# 2. 去掉指定位置为1的维度,注意: 如果指定位置不是1则不删除new_data = data.squeeze(2)print('new_data shape:', new_data.size())  # torch.Size([3, 5])# 3. 在2维度增加一个维度new_data = data.unsqueeze(-1)print('new_data shape:', new_data.size())  # torch.Size([3, 1, 5, 1])if __name__ == '__main__':test()

5. 小节

本小节我们学习了经常使用的关于张量形状的操作,我们用到的主要函数有:

  1. reshape 函数可以在保证张量数据不变的前提下改变数据的维度.
  2. transpose 函数可以实现交换张量形状的指定维度, permute 可以一次交换更多的维度.
  3. view 函数也可以用于修改张量的形状, 但是它要求被转换的张量内存必须连续,所以一般配合 contiguous 函数使用.
  4. squeeze 和 unsqueeze 函数可以用来增加或者减少维度.

http://www.ppmy.cn/devtools/155665.html

相关文章

neo4j入门

文章目录 neo4j版本说明部署安装Mac部署docker部署 neo4j web工具使用数据结构图数据库VS关系数据库 neo4j neo4j官网Neo4j是用ava实现的开源NoSQL图数据库。Neo4作为图数据库中的代表产品,已经在众多的行业项目中进行了应用,如:网络管理&am…

Vue.js `v-memo` 性能优化技巧

Vue.js v-memo 性能优化技巧 今天我们来聊聊 Vue 3.2 引入的一个性能优化指令:v-memo。如果你在处理大型列表或复杂组件时,遇到性能瓶颈,那么 v-memo 可能会成为你的得力助手。 什么是 v-memo? v-memo 是 Vue 3.2 新增的内置指…

Vue.js组件开发-实现全屏图片文字缩放切换特效

使用 Vue 实现全屏图片文字缩放切换特效 步骤 创建 Vue 项目:使用 Vue CLI 来快速创建一个新的 Vue 项目。设计组件结构:创建一个包含图片和文字的组件,并实现缩放和切换效果。实现样式:使用 CSS 来实现全屏显示、缩放和切换动画…

HarmonyOS:ForEach:循环渲染

一、前言 ForEach接口基于数组类型数据来进行循环渲染,需要与容器组件配合使用,且接口返回的组件应当是允许包含在ForEach父容器组件中的子组件。例如,ListItem组件要求ForEach的父容器组件必须为List组件。 API参数说明见:ForEa…

matlab提取滚动轴承故障特征

为了精准、稳定地提取滚动轴承故障特征,提出了基于变分模态分解和奇异值分解的特征提取方法,采用标准模糊C均值聚类(fuzzy C means clustering, FCM)进行故障识 别。对同一负荷下的已知故障信号进行变分模态分解,利用 奇异值分解技术进一步提…

PyQt5超详细教程终篇

PyQt5超详细教程 前言 接: [【Python篇】PyQt5 超详细教程——由入门到精通(序篇)](【Python篇】PyQt5 超详细教程——由入门到精通(序篇)-CSDN博客) 建议把代码复制到pycahrm等IDE上面看实际效果,方便理…

c++可变参数详解

目录 引言 库的基本功能 va_start 宏: va_arg 宏 va_end 宏 va_copy 宏 使用 处理可变参数代码 C11可变参数模板 基本概念 sizeof... 运算符 包扩展 引言 在C编程中,处理不确定数量的参数是一个常见的需求。为了支持这种需求,C标准库提供了 &…

Flask数据的增删改查(CRUD)_flask删除数据自动更新

查询年龄小于17的学生信息 Student.query.filter(Student.s_age < 17) students Student.query.filter(Student.s_age.__lt__(17))模糊查询&#xff0c;使用like&#xff0c;查询姓名中第二位为花的学生信息 like ‘_花%’,_代表必须有一个数据&#xff0c;%任何数据 st…