PyTorch gather与scatter_详解

devtools/2024/10/22 5:00:00/

PyTorch gather与scatter_详解

在 PyTorch 常用的算子中,有两个理解巅峰的存在,那就是 torch.gathertorch.scatter_,在 Seq2SeqAttentioncrf viterbi等结构的源码中,都可以看到这两个算子的身影,今天来详细讲解一下这两个函数。

torch.gather

使用

torch.gather 函数用于从输入张量的指定维度收集元素。收集的索引由 index 张量提供。

使用语法:torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

核心参数

  • input:输入张量
  • dim:指定的维度
  • index:索引张量,包含收集元素的索引

注意

  • inputindex 必须要有相同的维度
  • 对于所有的 d != dim,都必须要有 index.size(d) <= input.size(d)以及out 的形状和 index形状相同
  • inputindex 之间没有广播机制
  • 只有在 src.shape == index.shape 时实现了反向传播
说明

以一个三维的张量为例

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

看一个示意图(这里index和dim都是从1开始,转换成代码时 -1 即可
img

再看一个示意图,应该懂了

  • dim=0
    img
dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])
# 将 index 的 dim=0 处固定 然后其他位置按顺序填充
# [['0'-0, '1'-1, '2'-2],  ['1'-0, '2'-1, '0'-2]]
# [[(0, 0), (1, 1), (2, 2)], [(1, 0), (2, 1), (0, 2)]]output = torch.gather(input, dim, index)
# tensor([[10, 14, 18],
#         [13, 17, 12]])
  • dim=1
    img
dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]])
# 将 index 的 dim=1 处固定 然后其他位置按顺序填充
# [[0-'0', 0-'1'], [1-'1', 1-'2'], [2-'2', 2-'0']]
# [[(0, 0), (0, 1)], [(1, 1), (1, 2)], [(2, 2), (2, 0)]]output = torch.gather(input, dim, index)
# tensor([[10, 11],
#         [14, 15],
#         [18, 16]])
案例

假设我们有一个 2D 张量 data,我们希望根据索引张量 indexdata 中提取特定位置的值。

import torch# 创建一个 2D 张量 data
data = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
print("Data tensor:")
print(data)# 创建一个索引张量 index
index = torch.tensor([[0, 2],[1, 0],[2, 1]])
print("\nIndex tensor:")
print(index)# 使用 gather 函数
result = torch.gather(data, 1, index)
print("\nGathered result:")
print(result)

我们对上面案例进行逐步解释

  1. 初始张量 data:

    data = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
    

    这是一个 3x3 的张量:

    tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
    
  2. 索引张量 index:

    index = torch.tensor([[0, 2],[1, 0],[2, 1]])
    

    这是一个 3x2 的张量,表示要从 data 中提取的索引。

  3. 使用 gather 函数:

    result = torch.gather(data, 1, index)
    

    这个操作会根据 index 张量中的索引,从 data 张量中提取相应位置的值。具体操作如下:

    • 对于 data 的第 0 行:

      • index[0, 0] = 0,所以 result[0, 0] = data[0, 0] = 1
      • index[0, 1] = 2,所以 result[0, 1] = data[0, 2] = 3
    • 对于 data 的第 1 行:

      • index[1, 0] = 1,所以 result[1, 0] = data[1, 1] = 5
      • index[1, 1] = 0,所以 result[1, 1] = data[1, 0] = 4
    • 对于 data 的第 2 行:

      • index[2, 0] = 2,所以 result[2, 0] = data[2, 2] = 9
      • index[2, 1] = 1,所以 result[2, 1] = data[2, 1] = 8

最终,result 张量为:

tensor([[1, 3],[5, 4],[9, 8]])

torch.scatter_

使用

torch.scatter_ 是 PyTorch 中一个用于在特定维度上根据索引将值写入张量的原地操作函数。

使用语法:Tensor.scatter_(dim, index, src, *, reduce=None) → Tensor

核心参数

  • dim:指定沿着哪个维度进行散射操作
  • index:一个包含索引的张量,指定 src 中的值要写入 tensor 的位置
  • src:包含要写入 tensor 的值的张量

注意

  • self, indexsrc必须有相同的维度
  • 对于所有的维度 d 必须有 index.size(d) <= src.size(d)以及index.size(d) <= self.size(d)
  • indexsrc 不会进行广播
说明

torch.scatter_ 其实就是torch.gather 的一个逆运算

以一个三维的张量为例

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

另外,需要注意,scatter_ 是一个 inplace 算子

案例

先来看 dim=0 的情况

import torch
import numpy as np
src = torch.arange(1, 11).view(2, 5)
print(src)
> tensor([[ 1,  2,  3,  4,  5],[ 6,  7,  8,  9, 10]])input_tensor = torch.zeros(3, 5).long()
print(input_tensor)
> tensor([[0, 0, 0, 0, 0],[0, 0, 0, 0, 0],[0, 0, 0, 0, 0]])index_tensor = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
print(index_tensor)
> tensor([[0, 1, 2, 0, 0],[2, 0, 0, 1, 2]])## try to manually work out the result 
dim = 0
input_tensor.scatter_(dim, index_tensor, src)
print(input_tensor)
> ...
  • step1:将 src 的第1列分散到 input _tensor 的第1列。与指数张量的第1列相匹配。我们把1分散到0排,6分散到2排。

img

  • step2:将 src 的第2列分散到 input _ tensor 的第2列。与指数张量第2列匹配。我们把2分散到第1排,把7分散到第0排。

img

  • step3/4/5:以此类推,继续对其他列做散射。最后,我们将得到如下图。

img

运行代码,检查最终结果

> tensor([[ 1,  7,  8,  4,  5],[ 0,  2,  0,  9,  0],[ 6,  0,  3,  0, 10]])

再来看 dim=1 的情况

origin data

import torchsrc = torch.arange(1, 11).view(2, 5)
input_tensor = torch.zeros(3, 5).long()
index_tensor = torch.tensor([[3, 0, 2, 1, 4], [2, 0, 1, 3, 1]])
dim = 1
input_tensor.scatter_(dim, index_tensor, src)
print(input_tensor)
  • step1:将 src 的第一行散布到 input _ tensor 的第一行。1到 col3,2到 col0,3到 col2,4到 col1,5到 col4。

img

  • step2:将 src 的第2行散布到 input _ tensor 的第2行。

注意:index _ tensor 的第二行有两个1。为了使更新更清晰,我将这一步分为两个子步骤。

  • step2.1:分散6到 col2,7到 col0,8到 col1,9到 col3。

img

  • step2.2:对10进行分散,相应的索引是1,但是该位置8已经存在了,我们需要用10来覆盖8。

img

运行代码,检查最终结果

> tensor([[ 2,  4,  3,  1,  5],[ 7, 10,  6,  9,  0],[ 0,  0,  0,  0,  0]])

参考

PyTorch torch.gather

PyTorch torch.scatter_

What does gather() do in PyTorch

Understand torch.scatter_()


http://www.ppmy.cn/devtools/125550.html

相关文章

【回顾原生JDBC手动管理事务以及两种方式实现Spring编程式事务】

文章目录 一.关于事务1.事务概念2.事务四个基本特性3. 事务的生命周期4.事务的隔离级别5.事务的应用场景 二.回顾原生JDBC手动管理事务三.Spring编程式事务1.使用 TransactionTemplate 进行编程式事务管理2.使用 PlatformTransactionManager 进行编程式事务管理 四.编程式事务的…

【重学 MySQL】六十五、auto_increment 的使用

【重学 MySQL】六十五、auto_increment 的使用 创建表时使用 AUTO_INCREMENT特点和要求插入数据查看当前 AUTO_INCREMENT 值设置初始 AUTO_INCREMENT 值重置 AUTO_INCREMENT 值注意事项示例&#xff1a;组合主键和 AUTO_INCREMENTMySQL8.0 新特性&#xff1a;自增变量的持久化背…

正点原子学习笔记之汇编LED驱动实验

1 汇编LED原理分析 为什么要写汇编     需要用汇编初始化一些SOC外设     使用汇编初始化DDR、I.MX6U不需要     设置sp指针&#xff0c;一般指向DDR&#xff0c;设置好C语言运行环境 1.1 LED硬件分析 可以看到LED灯一端接高电平&#xff0c;一端连接了GPIO_3上面…

栈溢出0x0C 前置技能:栈迁移

栈迁移原因&#xff1a; 在完成一般的栈溢出攻击时&#xff0c;有一个充分条件是「栈上有足够的地方让攻击者进行布局」。通常的函数栈剩余空间是足够放置一些恶意指令的&#xff0c;但也有少数极端情况&#xff0c;例如仅能容纳一个 ret与一个 ebp。此时&#xff0c;一般的栈…

蓝桥杯:求平均年龄

#include<stdio.h> int main() { int num 0; float age 0,sum0; printf("请输入总人数: "); scanf_s("%d" ,& num); for (int i1; i <num;i) { scanf_s("%f", &age); sum age…

【部署篇】Redis-04哨兵模式部署(源码方式安装)

一、准备主机 Redis的哨兵模式是生产环境中常用的部署模式之一&#xff0c;解决数据容灾和单点故障问题&#xff0c;实现主从自动切换&#xff1b;生产环境中建议让sentinel&#xff08;哨兵&#xff09;单独部署&#xff0c;如果资源有限可以和数据节点部署在同一主机。 主…

Java:数据结构-LinkedList和链表(2)

一 LinkedList LinkedList的方法的实现 1.头插法 public class MyLinkedList implements IList{static class ListNode{public int val;public ListNode next;public ListNode prev;public ListNode(int val){this.valval;}}public ListNode head;public ListNode last;Overr…

neovim ubuntu中WARNING No clipboard tool found

我在vnc远程的ubuntu中做个临时开发&#xff0c;发现neovim无法复制文字&#xff0c;于是我:checkhealth查看了一下&#xff0c;测试结果如下&#xff1a; WARNING No clipboard tool found. Clipboard registers (" and "*) will not work.ADVICE::help clipboard …