深度学习02-pytorch-06-张量的形状操作

ops/2024/12/23 5:39:36/

在 PyTorch 中,张量的形状操作是非常重要的,可以让你灵活地调整和处理张量的维度和数据结构。以下是一些常用的张量形状函数及其用法,带有详细解释和举例说明:

1. reshape()

功能: 改变张量的形状,但不改变数据的顺序。

语法: tensor.reshape(*shape)

示例:

import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
x_reshaped = x.reshape(3, 2)
print(x_reshaped)

输出:

tensor([[1, 2],[3, 4],[5, 6]])

在这个例子中,张量 x 被从形状 (2, 3) 重塑为 (3, 2)

2. squeeze()

功能: 去除张量中大小为1的维度(例如,形状是 (1, 3, 1) 会变成 (3))。

语法: tensor.squeeze(dim=None)

示例:

x = torch.tensor([[[1, 2, 3]]])  # shape: (1, 1, 3)
x_squeezed = x.squeeze()
print(x_squeezed)

输出:

tensor([1, 2, 3])

在这个例子中,squeeze() 去除了前两个大小为1的维度。

3. unsqueeze()

功能: 在指定的维度插入大小为1的新维度。

语法: tensor.unsqueeze(dim)

示例:

x = torch.tensor([1, 2, 3])  # shape: (3,)
x_unsqueezed = x.unsqueeze(0)  # 插入新的0维
print(x_unsqueezed.shape)  # 输出: torch.Size([1, 3])

在这个例子中,

unsqueeze(0) 在第0个维度插入一个新的大小为1的维度,将形状从 (3,) 变成 (1, 3)

4. transpose()

功能: 交换张量的两个维度。

语法: tensor.transpose(dim0, dim1)

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: (2, 3)
x_transposed = x.transpose(0, 1)
print(x_transposed)

输出:

tensor([[1, 4],[2, 5],[3, 6]])

在这个例子中,transpose(0, 1) 交换了维度0和维度1,使张量的形状从 (2, 3) 变成 (3, 2)

5. permute()

功能: 改变张量的维度顺序,允许对多个维度进行交换。

语法: tensor.permute(*dims)

示例:

x = torch.randn(2, 3, 5)  # shape: (2, 3, 5)
x_permuted = x.permute(2, 0, 1)
print(x_permuted.shape)  # 输出: torch.Size([5, 2, 3])

在这个例子中,permute(2, 0, 1) 重新排列了维度顺序,

使得形状从 (2, 3, 5) 变为 (5, 2, 3)

6. view()

功能: 类似于 reshape(),但是 view() 需要张量在内存中是连续的。

语法: tensor.view(*shape)

示例:

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
x_viewed = x.view(3, 2)
print(x_viewed)

输出:

tensor([[1, 2],[3, 4],[5, 6]])

view() 的使用需要张量是连续的,否则会报错。

关于连续性,可以结合 contiguous() 使用。

7. contiguous()

功能: 将非连续的张量转换为在内存中连续存储的张量。

语法: tensor.contiguous()

示例:

x = torch.randn(2, 3, 5)
x_permuted = x.permute(2, 0, 1)  # 这使得张量不再连续
x_contiguous = x_permuted.contiguous().view(5, 6)  # 转换为连续后再进行view操作
print(x_contiguous.shape)

permute() 操作后的张量不一定是连续的,因此需要 contiguous() 来保证可以使用 view()

8. expand()repeat()

功能: 扩展张量到更高的维度。

  • expand() 只是广播,不复制内存。

  • repeat() 会实际复制数据。

示例:

x = torch.tensor([1, 2, 3])
x_expanded = x.expand(3, 3)  # 广播
x_repeated = x.repeat(3, 1)  # 重复数据
print(x_expanded)
print(x_repeated)

输出:

tensor([[1, 2, 3],[1, 2, 3],[1, 2, 3]])
​
tensor([[1, 2, 3],[1, 2, 3],[1, 2, 3]])

区别在于 expand() 不会占用更多内存,而 repeat() 会真正复制数据。

总结

上述这些张量操作函数在处理多维数据时非常有用,能够灵活地调整和转换张量的形状,以便进行各种操作和模型设计。


http://www.ppmy.cn/ops/114264.html

相关文章

C#描述-计算机视觉OpenCV(6):形态学

C#描述-计算机视觉OpenCV(6):形态学 前言阈值化二值图像腐蚀与膨胀算法形态学滤波器开启和闭合运算原理概括 前言 这是本系列第六节,主要是介绍基础的形态学运用。 形态学主要是分析图像中不同主题的形态,它定义了一系…

基于SpringBoot+Vue的宠物医院管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于JavaSpringBootVueMySQL的…

unity 高性能对象池解决方案

对于一个高性能对象池应该有的功能: 支持多种对象、同步加载、异步加载、隐藏方式、自动收缩(使用LRU缓存机制,最近最久未使用来进行收缩)、异步删除 所以我针对这几个功能讲一下原理: 支持多种对象: G…

Spring Session

Session 共享问题 在 Web 项目开发中,Session 会话管理是一个很重要的部分,用于存储与记录用户的状态或相关的数据。 通常情况下 session 交由容器(tomcat)来负责存储和管理,但是如果项目部署在多台 tomcat 中&#…

某文书网爬虫逆向

一、抓包分析 请求参数和响应数据都有加密 二、逆向分析 老方法、下xhr断点 加密实现逻辑都在这个方法里 执行到这的时候,在向下跟栈数据就已经渲染出来了,说明是在这个方法里进行的解密 解密方法,data.result为加密数据,data.s…

MyBatis 源码解析:Mapper 文件加载与解析

引言 在 MyBatis 中,Mapper 文件扮演了至关重要的角色,它通过 SQL 映射文件来定义数据库查询操作和 Java 对象之间的映射关系。Mapper 文件通常是以 XML 格式存储的,包含了 SQL 语句以及与 Java 对象的对应关系。在本篇文章中,我…

Java异常架构与异常关键字

1. Java异常简介 Java 异常是 Java 提供的一种识别及响应错误的一致性机制。 Java 异常机制可以使程序中异常处理代码和正常业务代码分离,保证程序代码更加优雅,并提高程 序健壮性。在有效使用异常的情况下,异常能清晰的回答 what, where,…

常见的限流算法

限流算法是用于控制访问频率、保护系统免受过载攻击的重要手段。常见的限流算法有以下几种,每种算法都有不同的应用场景和优缺点。下面是几种常见的限流算法的详细介绍: 1. 计数器算法(Counter) 原理 计数器算法是最简单的限流…