注意力机制:让AI拥有“黄金七秒记忆“的魔法----(点积注意力)

ops/2025/3/16 16:10:31/

注意力机制:让AI拥有"黄金七秒记忆"的魔法–(点积注意力)

注意⼒机制对于初学者来说有点难理解,我们⼀点⼀点地讲。现在先暂时忘记编码器、解码器、隐藏层和序列到序列这些概念。想象我们有两个张量x1和x2,我们希望⽤注意⼒机制把它俩给衔接起来,让x1看⼀看,x2有哪些特别值得关注的地⽅。

具体来说,要得到x1对x2的点积注意⼒,我们可以按照以下步骤进⾏操作。

(1)创建两个形状分别为(batch_size, seq_len1, feature_dim)和(batch_size, seq_len2, feature_dim)的张量x1和x2。

(2)将x1中的每个元素和x2中的每个元素进⾏点积,得到形状为 (batch_size, seq_len1, seq_len2)的原始权重raw_weights

(3)⽤softmax函数对原始权重进⾏归⼀化,得到归⼀化后的注意⼒权重attn_weights(注意⼒权重的值在0和1之间,且每⼀⾏的和为1),形状仍为 (batch_size, seq_len1, seq_len2)。

(4)⽤注意⼒权重attn_weights对x2中的元素进⾏加权求和(与x2相乘),得到输出张量y,形状为 (batch_size, seq_len1, feature_dim)。这就是x1对x2的点积注意⼒。

程序结构如下:

image-20250314201213369

一、点积注意机制

import torch # 导入 torch
import torch.nn.functional as F # 导入 nn.functional
# 1. 创建两个张量 x1 和 x2
x1 = torch.randn(2, 3, 4) # 形状 (batch_size, seq_len1, feature_dim)
x2 = torch.randn(2, 5, 4) # 形状 (batch_size, seq_len2, feature_dim)
# 2. 计算原始权重
raw_weights = torch.bmm(x1, x2.transpose(1, 2)) # 形状 (batch_size, seq_len1, seq_len2)
# 3. 用 softmax 函数对原始权重进行归一化
attn_weights = F.softmax(raw_weights, dim=2) # 形状 (batch_size, seq_len1, seq_len2)
# 4. 将注意力权重与 x2 相乘,计算加权和
attn_output = torch.bmm(attn_weights, x2)  # 形状 (batch_size, seq_len1, feature_dim)

1.1 创建两个张量x1和x2

# 创建两个张量 x1 和 x2
x1 = torch.randn(2, 3, 4) # 形状 (batch_size, seq_len1, feature_dim)
x2 = torch.randn(2, 5, 4) # 形状 (batch_size, seq_len2, feature_dim)
print("x1:", x1)
print("x2:", x2)

1.2 计算张量点积,得到原始权重

# 计算点积,得到原始权重,形状为 (batch_size, seq_len1, seq_len2)
raw_weights = torch.bmm(x1, x2.transpose(1, 2))
print(" 原始权重:", raw_weights) 

因为是对x1和x2的两个特征维度进行点积后归一化,所以要对x2数组进行转置。

image-20250314221502682

⽐如,输出结果的第⼀⾏[ 1.2474, -0.6254, 1.4849, 2.9333, -0.1787]就代表着本批次第⼀个x1序列中第⼀个元素(每个x1序列有3个元素,所以第⼀批次共3⾏)与x2中第⼀批次5个元素的每⼀个元素的相似度得分(不难看出,x1中第⼀个元素与x2中第4个元素最相似,原始注意⼒分值为2.9333)。

相似度的计算是注意⼒机制最核⼼的思想

因为点积其实可以一定程度上反应向量方向的相关性,所以通过将x1的元素与x2的各个元素进行点积就可以求出权重(原始得分)其中在x2中权重较大的(得分高的)对应的词即为与x1中相关度最高的,故x2可以根据这个原始的分来判断应该输出那些对应x1相关性最高的内容。

x1起到编译器,x2起到译码器的作用

在某些⽂献或代码中,有时会将相似度得分称为原始权重。这是因为它们实际上是在计算注意⼒权重之前的中间结果。严格来说,相似度得分表示输⼊序列中不同元素之间的关联性或相似度,⽽权重则是在应⽤某些操作(如缩放、掩码和归⼀化)后得到的归⼀化值。为了避免混淆,可以将这两个术语彻底区分开。

通常,将未处理的值称为得分,并在经过处理后将它们称为权重。这有助于更清晰地理解注意⼒机制的⼯作原理及其不同组件。

举一个栗子:

让我们⽤下⾯的图示来对向量点积和相似度得分进⾏相对直观的理解。在下图的例⼦中,有两个向量——电影的特征(M)和⽤户的兴趣(U

image-20250314222742613

向量U中可能蕴含⽤户是否喜欢爱情⽚、喜欢动作⽚等信息;⽽向量M中则包含电影含有动作、浪漫等特征的程度。

通过计算UM的点积或相似度得分,我们可以得到⼀个衡量UM兴趣程度的分数。例如,如果向量U中喜欢爱情⽚、喜欢动作⽚的权重较⾼,⽽向量M中的动作和浪漫特征的权重也较⾼,那么计算得到的点积或相似度得分就会⽐较⾼,表示UM的兴趣较⼤,系统有可能推荐这部电影给⽤户。

1.3 对原始权重进行归一化

import torch.nn.functional as F # 导入 torch.nn.functional
# 应用 softmax 函数,使权重的值在 0 和 1 之间,且每一行的和为 1
attn_weights = F.softmax(raw_weights, dim=-1) # 归一化
print(" 归一化后的注意力权重:", attn_weights)

所谓的归⼀化,其实理解起来很简单。得到每⼀个x1序列中的元素与其所对应的5个x2序列元素的相似度得分后,使⽤softmax函数进⾏缩放,让这5个数加起来等于1。

image-20250314223142527

 归一化后的注意力权重: tensor([[[0.3154, 0.2383, 0.2145, 0.1589, 0.0729],[0.0015, 0.9234, 0.0090, 0.0015, 0.0645],[0.0533, 0.0576, 0.5788, 0.0858, 0.2245]],[[0.4959, 0.0374, 0.1558, 0.0349, 0.2760],[0.0034, 0.0470, 0.0424, 0.8826, 0.0246],[0.2597, 0.0678, 0.0840, 0.1356, 0.4530]]])

归⼀化后,attn_weights(权重)和raw_weights(得分)形状相同,但是值变了,第⼀⾏的5个数字加起来刚好是1。第4个数字是0.6697,这就表明:在本批次的第⼀⾏数据中,x2序列中的第4个元素和x1序列的第1个元素特别相关,应该加以注意。

1.4 求出注意力机制的加权和

注意⼒权重与x2相乘,就得到注意⼒分布的加权和

换句话说,我们将x2中的每个位置向量乘以它们在x1中对应位置的注意⼒权重,然后将这些加权向量求和——这是点积注意⼒计算的最后⼀个环节。这⼀步的⽬的是根据注意⼒权重计算x2的加权和。这个加权和才是x1对x2的注意⼒输出

加权只是对应着一个关系表,并不代表输出。

相当于在一个函数中,已经求得了对应关系,现在需要给一个输入,才能得出一个输出值。

# 与 x2 相乘,得到注意力分布的加权和,形状为 (batch_size, seq_len1, feature_dim)
attn_output = torch.bmm(attn_weights, x2)
print(" 注意力输出 :", attn_output)

image-20250314224208498

en1, feature_dim)
attn_output = torch.bmm(attn_weights, x2)
print(" 注意力输出 :", attn_output)

[外链图片转存中...(img-ujffJwJa-1742036282108)]

http://www.ppmy.cn/ops/166239.html

相关文章

Spring 事务失效的 8 种场景!

在日常工作中,如果对Spring的事务管理功能使用不当,则会造成Spring事务不生效的问题。而针对Spring事务不生效的问题,也是在跳槽面试中被问的比较频繁的一个问题。 点击上方卡片关注我 今天,我们就一起梳理下有哪些场景会导致Sp…

3. 无重复字符的最长子串

给定一个字符串 s ,请你找出其中不含有重复字符的 最长 子串 的长度。 示例 1: 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是 "abc",所以其长度为 3。示例 2: 输入: s "bbbbb" 输出: 1 解释: 因为无…

Linux内核,mmap_pgoff在mmap.c的实现

1. mmap_pgoff的系统调用实现如下 SYSCALL_DEFINE6(mmap_pgoff, unsigned long, addr, unsigned long, len,unsigned long, prot, unsigned long, flags,unsigned long, fd, unsigned long, pgoff) {return ksys_mmap_pgoff(addr, len, prot, flags, fd, pgoff); }2. ksys_mma…

MyBatis框架操作数据库一>xml和动态Sql

目录 配置连接字符串和MyBatis:数据库的连接配置:XML的配置: XML编写Sql:model层:mapper层: 动态Sql:if 标签和trim标签:where标签:Set标签:Foreach标签: Mybatis的开发有两种方式:: 注解和XML&…

FPGA前端设计适合哪些人学?该怎么学?

FPGA前端设计是一个具有挑战性且薪资待遇优渥的岗位,主要涉及FPGA芯片定义、逻辑结构设计。这个职位要求相关专业的本科及以上学历,并且需要掌握一定的专业技能。工作内容从IP级设计到全芯片(SoC)设计,涉及多个设计层级…

iOS应用程序开发(图片处理器)

续上篇 iOS 编程开发图片浏览器,继续实现一个图标生成功能。 功能。 操作系统平台:MacBook(macOS) IDE:xcode 编程语言:Objective-c 以下是小程序在 iPhone 模拟器中的运行视频。也可以在 iPad 模拟器中运行。 效果图如下所示&#xff1a…

Linux find 命令完全指南

find 是 Linux 系统最强大的文件搜索工具&#xff0c;支持 嵌套遍历、条件筛选、执行动作。以下通过场景分类解析核心用法&#xff0c;涵盖高效搜索、文件管理及高级技巧&#xff1a; 一、基础搜索模式 1. 按文件名搜索&#xff08;精确/模糊匹配&#xff09; <BASH> f…

洛谷 P1725 琪露诺 单调队列优化的线性dp

以上是题目 考虑到2e5的数据范围&#xff0c;暴力的先枚举i&#xff0c;在枚举走的步数区间j&#xff0c;是过不了的&#xff0c; 我们可以看出对于每一个i&#xff0c;只需要找出能走的i的区间的dp最大值即可&#xff0c;求区间最大值可以使用单调队列&#xff0c;时间复杂度…