PyTorch torch.unbind、torch.split 和 torch.chunk函数介绍

devtools/2025/2/9 6:34:19/

pytorch torch.unbind、torch.split 和 torch.chunk等函数可用于张量的拆分操作。

1. torch.unbind

功能说明:
torch.unbind 沿指定的维度将张量“解包”为多个张量,返回一个元组。解包后被操作的那个维度会消失,每个输出张量的维度数会比原张量少 1。

函数签名:

torch.unbind(input, dim=0)
  • input: 待解包的张量。
  • dim: 指定沿哪个维度解包。

多维张量示例:

import torch# 创建一个形状为 (2, 3, 4) 的张量
x = torch.arange(24).reshape(2, 3, 4)
print("原始张量 x:\n", x)
print("x.shape:", x.shape)  # torch.Size([2, 3, 4])# 沿第 0 维解包
slices0 = torch.unbind(x, dim=0)
print("\n沿 dim=0 解包:")
for i, t in enumerate(slices0):print(f"slice {i} shape: {t.shape}")# 每个张量形状为 (3, 4)# 沿第 1 维解包
slices1 = torch.unbind(x, dim=1)
print("\n沿 dim=1 解包:")
for i, t in enumerate(slices1):print(f"slice {i} shape: {t.shape}")# 每个张量形状为 (2, 4)

2. torch.split

功能说明:
torch.split 根据给定的大小或尺寸列表,将张量沿指定维度切分成若干块。

  • 如果传入一个整数,则每块的大小为该整数,最后一块可能会小于这个整数。
  • 如果传入一个尺寸列表,则按列表中指定的尺寸进行切分。

函数签名:

torch.split(tensor, split_size_or_sections, dim=0)
  • tensor: 待分割的张量。
  • split_size_or_sections: 整数或尺寸列表,指定每块的大小。
  • dim: 指定沿哪个维度进行切分。

多维张量示例:

import torch# 创建一个形状为 (2, 5, 4) 的张量
x = torch.arange(40).reshape(2, 5, 4)
print("原始张量 x:\n", x)
print("x.shape:", x.shape)  # torch.Size([2, 5, 4])# 按照固定大小进行切分:沿第 1 维,每块大小为 2
splits_fixed = torch.split(x, 2, dim=1)
print("\n沿 dim=1 按固定大小 2 切分:")
for i, t in enumerate(splits_fixed):print(f"chunk {i} shape: {t.shape}")# 输出块的形状可能为 (2, 2, 4), (2, 2, 4) 和最后一块 (2, 1, 4)# 按照指定尺寸列表进行切分:沿第 1 维,分块尺寸为 [1, 2, 2]
splits_list = torch.split(x, [1, 2, 2], dim=1)
print("\n沿 dim=1 按尺寸列表 [1, 2, 2] 切分:")
for i, t in enumerate(splits_list):print(f"chunk {i} shape: {t.shape}")# 分别输出形状 (2, 1, 4), (2, 2, 4), (2, 2, 4)

3. torch.chunk

功能说明:
torch.chunk 将张量沿指定维度平均分成指定数量的块。如果张量在该维度上的长度不能被块数整除,则前面的块会比后面块多一个元素(块的尺寸差别最多为 1)。

函数签名:

torch.chunk(tensor, chunks, dim=0)
  • tensor: 待分割的张量。
  • chunks: 指定分成几块。
  • dim: 指定沿哪个维度进行分块。

多维张量示例:

对比总结

函数分割方式返回结果形式适用场景
torch.unbind沿指定维度将张量完全解包,每个输出不含该维度元组,输出张量数 = 该维度的长度需要逐个处理某一维度上的切片,且希望移除该维度时使用。
torch.split按照指定大小或尺寸列表切分张量元组或列表需要按固定大小或自定义尺寸列表切分张量,最后一块可能不均匀。
torch.chunk将张量均匀分成指定数量的块元组或列表希望将张量平均分成若干块,块数固定,自动处理无法整除的情况。

注意:

  • 当处理多维张量时,选择沿哪一维进行分割非常重要;
  • torch.unbind 会移除分割的那个维度,而 torch.split 和 torch.chunk 则保持原始维度,只是该维度上的大小发生变化。

通过这些示例代码和说明,你可以根据具体需求选择合适的函数来分割多维张量。


http://www.ppmy.cn/devtools/157269.html

相关文章

为什么要设计DTO类/什么时候设置DTO类?

为什么设计DTO类? 例如:根据新增员工接口设计对应的DTO 前端传递参数列表: 思考:是否可以使用对应的实体类来接收呢? 注意:前端提交的数据和实体类中对应的属性差别比较大,所以自定义DTO类。 …

python基础入门:3.5实战:词频统计工具

Python词频统计终极指南:字典与排序的完美结合 import re from collections import defaultdictdef word_frequency_analysis(file_path, top_n10):"""完整的词频统计解决方案:param file_path: 文本文件路径:param top_n: 显示前N个高频词:return:…

【前端】【Ts】【知识点总结】TypeScript知识总结

一、总体概述 TypeScript 是 JavaScript 的超集,主要通过静态类型检查和丰富的类型系统来提高代码的健壮性和可维护性。它涵盖了从基础数据类型到高级类型、从函数与对象的类型定义到类、接口、泛型、模块化及装饰器等众多知识点。掌握这些内容有助于编写更清晰、结…

React 设计模式:实用指南

React 提供了众多出色的特性以及丰富的设计模式,用于简化开发流程。开发者能够借助 React 组件设计模式,降低开发时间以及编码的工作量。此外,这些模式让 React 开发者能够构建出成果更显著、性能更优越的各类应用程序。 本文将会为您介绍五…

如何在WPS和Word/Excel中直接使用DeepSeek功能

以下是将DeepSeek功能集成到WPS中的详细步骤,无需本地部署模型,直接通过官网连接使用:1. 下载并安装OfficeAI插件 (1)访问OfficeAI插件下载地址:OfficeAI助手 - 免费办公智能AI助手, AI写作,下载…

Spring Boot统一异常拦截实践指南

Spring Boot统一异常拦截实践指南 一、为什么需要统一异常处理 在Web应用开发中,异常处理是保证系统健壮性和用户体验的重要环节。传统开发模式中常见的痛点包括: 异常处理逻辑分散在各个Controller中错误响应格式不统一敏感异常信息直接暴露给客户端…

tcpdump能否抓到被iptable封禁的包

tcpdump 能否抓到被 iptable 封禁的包? tcpdump工作在设备层,将包送到IP层以前就能处理。而netfiter工作在IP、ARP等层。从图2.13收包流程处理顺序上来看,netfiter是在tcpdump后面工作的,所以iptable封禁规则影响不到tcpdump的抓包。 不过发…

Effective Objective-C 2.0 读书笔记—— 接口与API设计

Effective Objective-C 2.0 读书笔记—— 接口与API设计 文章目录 Effective Objective-C 2.0 读书笔记—— 接口与API设计1. 用前缀避免命名空间冲突2.提供"全能初始化方法"3.实现description方法4.尽量使用不可变对象5.理解Objective -C错误模型 1. 用前缀避免命名…