深度学习02-pytorch-06-张量的形状操作

embedded/2024/9/23 2:50:34/

在 PyTorch 中,张量的形状操作是非常重要的,可以让你灵活地调整和处理张量的维度和数据结构。以下是一些常用的张量形状函数及其用法,带有详细解释和举例说明:

1. reshape()

功能: 改变张量的形状,但不改变数据的顺序。

语法: tensor.reshape(*shape)

示例:

import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
x_reshaped = x.reshape(3, 2)
print(x_reshaped)

输出:

tensor([[1, 2],[3, 4],[5, 6]])

在这个例子中,张量 x 被从形状 (2, 3) 重塑为 (3, 2)

2. squeeze()

功能: 去除张量中大小为1的维度(例如,形状是 (1, 3, 1) 会变成 (3))。

语法: tensor.squeeze(dim=None)

示例:

x = torch.tensor([[[1, 2, 3]]])  # shape: (1, 1, 3)
x_squeezed = x.squeeze()
print(x_squeezed)

输出:

tensor([1, 2, 3])

在这个例子中,squeeze() 去除了前两个大小为1的维度。

3. unsqueeze()

功能: 在指定的维度插入大小为1的新维度。

语法: tensor.unsqueeze(dim)

示例:

x = torch.tensor([1, 2, 3])  # shape: (3,)
x_unsqueezed = x.unsqueeze(0)  # 插入新的0维
print(x_unsqueezed.shape)  # 输出: torch.Size([1, 3])

在这个例子中,

unsqueeze(0) 在第0个维度插入一个新的大小为1的维度,将形状从 (3,) 变成 (1, 3)

4. transpose()

功能: 交换张量的两个维度。

语法: tensor.transpose(dim0, dim1)

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: (2, 3)
x_transposed = x.transpose(0, 1)
print(x_transposed)

输出:

tensor([[1, 4],[2, 5],[3, 6]])

在这个例子中,transpose(0, 1) 交换了维度0和维度1,使张量的形状从 (2, 3) 变成 (3, 2)

5. permute()

功能: 改变张量的维度顺序,允许对多个维度进行交换。

语法: tensor.permute(*dims)

示例:

x = torch.randn(2, 3, 5)  # shape: (2, 3, 5)
x_permuted = x.permute(2, 0, 1)
print(x_permuted.shape)  # 输出: torch.Size([5, 2, 3])

在这个例子中,permute(2, 0, 1) 重新排列了维度顺序,

使得形状从 (2, 3, 5) 变为 (5, 2, 3)

6. view()

功能: 类似于 reshape(),但是 view() 需要张量在内存中是连续的。

语法: tensor.view(*shape)

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
x_viewed = x.view(3, 2)
print(x_viewed)

输出:

tensor([[1, 2],[3, 4],[5, 6]])

view() 的使用需要张量是连续的,否则会报错。

关于连续性,可以结合 contiguous() 使用。

7. contiguous()

功能: 将非连续的张量转换为在内存中连续存储的张量。

语法: tensor.contiguous()

示例:

x = torch.randn(2, 3, 5)
x_permuted = x.permute(2, 0, 1)  # 这使得张量不再连续
x_contiguous = x_permuted.contiguous().view(5, 6)  # 转换为连续后再进行view操作
print(x_contiguous.shape)

permute() 操作后的张量不一定是连续的,因此需要 contiguous() 来保证可以使用 view()

8. expand()repeat()

功能: 扩展张量到更高的维度。

  • expand() 只是广播,不复制内存。

  • repeat() 会实际复制数据。

示例:

x = torch.tensor([1, 2, 3])
x_expanded = x.expand(3, 3)  # 广播
x_repeated = x.repeat(3, 1)  # 重复数据
print(x_expanded)
print(x_repeated)

输出:

tensor([[1, 2, 3],[1, 2, 3],[1, 2, 3]])
​
tensor([[1, 2, 3],[1, 2, 3],[1, 2, 3]])

区别在于 expand() 不会占用更多内存,而 repeat() 会真正复制数据。

总结

上述这些张量操作函数在处理多维数据时非常有用,能够灵活地调整和转换张量的形状,以便进行各种操作和模型设计。


http://www.ppmy.cn/embedded/115364.html

相关文章

简单了解 JVM

目录 ♫什么是JVM ♫JVM的运行流程 ♫JVM运行时数据区 ♪虚拟机栈 ♪本地方法栈 ♪堆 ♪程序计数器 ♪方法区/元数据区 ♫类加载的过程 ♫双亲委派模型 ♫垃圾回收机制 ♫什么是JVM JVM 是 Java Virtual Machine 的简称,意为 Java虚拟机。 虚拟机是指通过软件模…

高级java每日一道面试题-2024年9月17日-框架篇-什么是ORM框架?

如果有遗漏,评论区告诉我进行补充 面试官: 如何处理事务中的性能问题? 我回答: 在Java高级面试中,理解ORM(Object-Relational Mapping,对象关系映射)框架是非常重要的。ORM框架是一种编程技术,用于将面向…

如何在微服务的日志中记录每个接口URL、状态码和耗时信息?

一、实现方式 1.直接通过SpringCloud-GateWay 的GlobalFilter实现 2.AOP反射自定义注解自己封装 二、具体实现 1.自定义注解 Target({ElementType.METHOD})//作用在方法上 Retention(RetentionPolicy.RUNTIME)//运行时生效 public interface MethodExporter{//自定义注解只…

Python知识点:如何使用Python进行算法交易

开篇,先说一个好消息,截止到2025年1月1日前,翻到文末找到我,赠送定制版的开题报告和任务书,先到先得!过期不候! 使用Python进行算法交易的完整指南 在当今快节奏的金融市场中,算法…

systemctl控制服务和守护进程

system守护进程介绍: systemd daemon(守护进程)管理linux的启动,包括服务的启动和管理 systemd可在系统引导时以及运行中的系统上激活系统资源、服务器守护进程和其他进程。 守护进程daemon是在后台运行或等待的进程,以执行不同的任务。通常daemon在系统启动时…

Android架构组件: MVVM模式的实战应用与数据绑定技巧

随着Android应用的复杂性增加,开发人员面临代码重用性、可维护性和扩展性问题。为了解决这些问题,谷歌推出了Android架构组件(Android Architecture Components),这套框架能帮助构建高效、可维护的应用。MVVM&#xff…

基于YOLOv5的教室人数检测统计系统

基于YOLOv5的教室人数检测统计系统可以有效地用于监控教室内的学生数量,适用于多种应用场景,比如 自动考勤、安全监控或空间利用分析 以下是如何构建这样一个系统的概述,包括环境准备、数据集创建、模型训练以及如何处理不同类型的媒体输入…

DNF Decouple and Feedback Network for Seeing in the Dark

DNF: Decouple and Feedback Network for Seeing in the Dark 在深度学习领域,尤其是在低光照图像增强的应用中,RAW数据的独特属性展现出了巨大的潜力。然而,现有架构在单阶段和多阶段方法中都存在性能瓶颈。单阶段方法由于域歧义&#xff0c…