pytorch torch.gather函数介绍

embedded/2024/10/22 13:45:00/

torch.gather 是 PyTorch 中的一个用于从给定维度上按索引取值的函数。它根据一个索引张量 index,从源张量 input 中收集值,并返回一个新的张量。torch.gather 常用于需要从张量的特定位置抽取元素的操作。

1. 函数签名

torch.gather(input, dim, index, *, sparse_grad=False, out=None)
  • input:输入张量,表示要从中收集元素的源张量。
  • dim:要收集的维度索引。例如,对于一个二维张量,0 表示沿着行的维度,1 表示沿着列的维度。
  • index:索引张量,其形状应与input张量在除了dim维度之外的其他维度上保持一致。索引张量中的值表示在input张量对应维度上要收集的元素的索引。
  • out(可选):输出张量,如果提供,结果将存储在这个张量中。

2. 工作原理

torch.gather 在 dim 维度上,通过 index 指定的索引,从 input 中选取元素。 返回的张量的形状与 index 的形状相同。

3. 示例代码

以下是一个简单的示例代码,演示如何使用 torch.gather 函数:

import torch# 创建一个源张量
input = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])# 创建一个索引张量
index = torch.tensor([[0, 2, 1],[2, 0, 1],[1, 2, 0]])# 在 dim=1 维度上使用 gather 函数
result = torch.gather(input, dim=1, index=index)print("Input Tensor:")
print(input)
print("\nIndex Tensor:")
print(index)
print("\nResult Tensor:")
print(result)

4. 输出结果

Input Tensor:
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])Index Tensor:
tensor([[0, 2, 1],[2, 0, 1],[1, 2, 0]])Result Tensor:
tensor([[1, 3, 2],[6, 4, 5],[8, 9, 7]])

5. 解释

  • 输入张量 (input) 是一个 3x3 的矩阵,每个元素代表一个值。
  • 索引张量 (index) 指定了要从 input 中提取的元素的索引。
  • 结果张量 (result) 是根据 index 从 input 中提取的元素形成的张量。

在这个例子中:

  • 对于 input 的第一行,index 提取了索引 0, 2, 1 对应的元素 1, 3, 2
  • 对于 input 的第二行,index 提取了索引 2, 0, 1 对应的元素 6, 4, 5
  • 对于 input 的第三行,index 提取了索引 1, 2, 0 对应的元素 8, 9, 7

6. 总结

  • torch.gather 通过索引在指定维度上提取张量中的元素,是用于基于索引选择数据的有用工具。
  • 函数对批处理数据特别有用,例如在分类任务中提取对应类别的概率或得分。
  • 索引张量的形状必须与源张量在指定维度的形状相匹配,以确保正确的取值操作。


http://www.ppmy.cn/embedded/108746.html

相关文章

# Windows下配置Redis以服务方式启动

Windows下配置Redis以服务方式启动 Redis以服务方式启动 winR快捷键打开运行窗口,输入cmd进入 DOS窗口。进入redis的安装目录。安装redis服务 , 输入命令 redis-server --service-install redis.windows.conf --loglevel verbose 启动服务&#xff0c…

第十七题:电话号码的字母组合

题目描述 给定一个仅包含数字 2-9 的字符串,返回所有可能的由它组成的字母组合。你可以假设输入字符串至少包含一个数字,并且不超过3位数字。 实现思路 使用哈希表或数组存储每个数字对应的字符,然后通过递归或迭代的方式生成所有可能的组…

快速失败 (fail-fast) 和安全失败 (fail-safe)

1. 定义与工作原理 1.1 快速失败(Fail-Fast) 定义: 快速失败是一种系统设计原则,当系统遇到异常情况或错误时,立即停止执行并返回错误,而不是试图继续执行或处理潜在的问题。快速失败系统会主动检测系统中…

基于Tomcat的JavaWeb(ASP)项目构建(图解)

目录 配置IDEA的TOMCAT环境 环境设置 导入API(可选) 创建项目 构建项目 ​编辑 运行项目 项目结果 ​编辑 查看配置基础项目 配置IDEA的TOMCAT环境 环境设置 导入API(可选) 创建项目 构建项目 运行项目 项目结果 查看配置基础项目 了解Web Application: Exploded与…

【免费分享】高斯过程回归(Gaussian process regression)原理详解及MATLAB代码实战

MATLAB实战 net fitrgp(p_train, t_train, KernelFunction, ardsquaredexponential, ...Optimizer, lbfgs, KernelParameters, [sigmaL0; sigmaF0], Sigma, sigmaN0);fitrgp 函数来训练一个 高斯过程回归模型 (Gaussian Process Regression, GPR)。具体来说,它在训…

javaWeb【day04】--(MavenSpringBootWeb入门)

01. Maven课程介绍 1.1 课程安排 学习完前端Web开发技术后,我们即将开始学习后端Web开发技术。做为一名Java开发工程师,后端Web开发技术是我们学习的重点。 1.2 初识Maven 1.2.1 什么是Maven Maven是Apache旗下的一个开源项目,是一款用于…

AF透明模式/虚拟网线模式组网部署

透明模式组网 实验拓扑 防火墙基本配置 接口配置 eth1 eth3 放通策略 1. 内网用户上班时间(9:00-17:00)不允许看视频、玩游戏及网上购物,其余时 间访问互联网不受限制;(20 分) 应用控制策略 2. 互联…

【Webpack】基本使用方法

📢博客主页:逆旅行天涯-CSDN博客 📢欢迎点赞👍收藏⭐留言📝如有错误敬请指正! 参考视频: 30 分钟掌握 Webpack_哔哩哔哩_bilibili 什么是webpack 简单来说就是一个 打包工具, 可…