Attention的总结

embedded/2024/10/22 12:23:21/

文章目录

  • Attention 机制
    • 计算过程
    • 实现
      • 维度设置
      • 具体维度
      • 多维张量乘法的基本规则
      • 广播的应用:
      • code
    • 矩阵乘法选择
      • 1. `torch.bmm()`(Batch Matrix Multiplication)
      • 2. `torch.matmul()`(Matrix Multiplication)
      • 3. `np.dot()`(NumPy Dot Product)
  • Multi-head Attention
    • 原理
    • 实现
  • Self-Attention
    • 计算过程
    • 代码
  • Mask 机制

Attention 机制

计算过程

有三个关键变量:

  • Q(uery):查询变量
  • K(ey):键
  • V(alue):值

这三个变量会作为输入用于Attention的计算。过程为:
在这里插入图片描述

  1. 模型使用Q与所有的K进行相似度计算(如点积等等)
  2. 为了避免点积的结果过大,softmax的效果差,我们将点积的结果和一个缩放因子(通常是键向量维度的平方根的倒数)进行缩放
  3. 将点积结果通过softmax函数转变成相似度的概率分布(即权重)
    4.将权重和值V进行加权相乘,获取结果输出

实现

维度设置

  • batch_size (B):批次大小,指一次处理的数据样本数量。
  • sequence_length (S):序列长度,如一句话中的单词数或一段时间序列的长度。
  • embedding_dimension (D):嵌入维度,指 Q、K、V 经过嵌入层或线性变换后的特征数量。

具体维度

  • Q 的维度:([B, S, D_q]),其中 (D_q) 是查询向量的维度。
  • K 的维度:([B, S, D_k]),其中 (D_k) 是键向量的维度。
  • V 的维度:([B, S, D_v]),其中 (D_v) 是值向量的维度。

其中D_k = D_q, N_k = N_v,别的不一定相等

多维张量乘法的基本规则

在这里插入图片描述

广播的应用:

在多维张量乘法中,广播规则允许操作两个形状不完全匹配的张量。假设你想乘的两个张量有不同的维度数量:

  • 如果两个张量的维度数不同,较小维度的张量将会在其较高维度前补充1,直到两个张量的维度数相同。
  • 从最后一个维度开始比较,每个维度要么相同,要么其中之一是1。

code

import numpy as np
import torch
from torch import nnclass ScaledDotProductAttention(nn.Module):# scale:缩放因子def __init__(self, scale):super().__init__()self.scale = scaleself.softmax = nn.Softmax(dim=2)def forward(self, q, k, v):# 1. q和k进行点积计算t = torch.matmul(q, k.transpose(1, 2))# 2.缩放t = t * self.scale# 3.softmax层attention = self.softmax(t)# 4.获取输出output = torch.bmm(attention, v)return attention, outputif __name__ == "__main__":n_q, n_k, n_v = 2, 4, 4d_q, d_k, d_v = 128, 128, 64batch = 64q = torch.randn(batch, n_q, d_q)k = torch.randn(batch, n_k, d_k)v = torch.randn(batch, n_v, d_v)attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))attn, output = attention(q, k, v)print(attn)print(output)

矩阵乘法选择

1. torch.bmm()(Batch Matrix Multiplication)

  • 用途:用于批量矩阵乘法。
  • 输入要求:输入张量必须是三维的,形状为 [batch_size, n, m][batch_size, m, p]
  • 输出:输出形状为 [batch_size, n, p],即每个批次的矩阵乘积结果。
  • 场景:对一组矩阵进行相同的矩阵乘法操作时使用,每个批次独立处理,常见于神经网络中处理多个数据样本的场景。

2. torch.matmul()(Matrix Multiplication)

  • 用途:更通用的矩阵乘法函数,可以处理两个张量的乘法,支持广播。
  • 输入要求:可以是两个任意维度的张量。对于两个1D张量,进行内积;对于2D张量,进行标准的矩阵乘法;对于高于2D的张量,执行批量矩阵乘法,最后两维进行矩阵乘法,前面的维度进行广播。
  • 输出:根据输入的维度,可能是标量、向量或矩阵。
  • 场景:需要处理不同维度或需要广播支持的复杂矩阵乘法时使用,非常灵活。

3. np.dot()(NumPy Dot Product)

  • 用途:NumPy 中的点积函数,行为取决于输入数组的维度。
  • 输入要求:可以是1D或2D数组,对于1D数组执行向量内积,对于2D数组执行矩阵乘法。
  • 输出:根据输入维度,可能是标量或矩阵。
  • 场景:在使用 NumPy 处理向量和矩阵运算时使用。注意,np.dot() 在面对高于二维的数组时不支持广播,这与 torch.matmul() 不同。

Multi-head Attention

原理

单头注意力指的就是注意力过程只求一次,即QKV进行上述过程一次;多头则是指求多次。
实现方式为:
将原始的QKV通过线性变化得到多组的QKV,这些组将从不同的空间来“理解”任务。
比如现在原始Q=“勺子”,线性变换后得到两个头部——餐具和食物,这样在第一个头部,就会注意到和筷子、刀叉等餐具之间的高相似度,在第二个头部,就会注意到和汤类食物的高相似度,这样借助多组QKV在不同的子空间上去理解原始输入的语义了。

实现

import numpy as np
import torch
from torch import nnfrom attention.Attention import ScaledDotProductAttentionclass MultiHeadAttention(nn.Module):"""n_head:分成几个头部d_k_:输入的Q和K的维度d_v_:输入的V的维度d_k:变换后Q和K的维度d_v:变换后V的维度d_o:输出维度"""def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):super().__init__()self.n_head = n_headself.d_k = d_kself.d_v = d_v# 将原始的维度进行多头变换,即映射到多个子空间上self.fc_q = nn.Linear(d_k_, n_head * d_k)self.fc_k = nn.Linear(d_k_, n_head * d_k)self.fc_v = nn.Linear(d_v_, n_head * d_v)# 每个头单独进行注意力计算self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))# 输出层,将多头注意力拼接起来self.fc_o = nn.Linear(n_head * d_v, d_o)def forward(self, q, k, v):n_head, d_q, d_k, d_v = self.n_head, self.d_k, self.d_k, self.d_v#获取参数batch, n_q, d_q_ = q.size()batch, n_k, d_k_ = k.size()batch, n_v, d_v_ = v.size()# 1.扩展成多头q = self.fc_q(q)k = self.fc_k(k)v = self.fc_v(v)"""q:[batch, n_q, n_head * d_q]view:重新塑性 q的维度=>[batch, n_q, n_head, d_q]permute:置换维度,即调整张量顺序,将原始的维度移到目标位置上去[n_head, batch, n_q, d_q]contiguous():由于permute操作可能会导致张量在内存中的存储不连续,使用.contiguous()确保张量在内存中连续存储view(-1, n_q, d_q):[n_head, batch, n_q, d_q] => [n_head * batch, n_q, d_q]最原始数据:假设batch = 2. n_head = 4Q:[数据1, 数据2, 数据3][数据4, 数据5, 数据6]变换后的:头1批次1: [数据1头1部分, 数据2头1部分, 数据3头1部分]头1批次2: [数据4头1部分, 数据5头1部分, 数据6头1部分]...头4批次1: [数据1头4部分, 数据2头4部分, 数据3头4部分]头4批次2: [数据4头4部分, 数据5头4部分, 数据6头4部分]"""q = q.view(batch, n_q, n_head, d_q).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_q)k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1, n_v, d_v)# 2.当成单头注意力求输出attn, output = self.attention(q, k, v)# 3.拼接多头的输出"""output:[n_head * batch, n_q, d_q]view(n_head, batch, n_q, d_v):[n_head, batch, n_q, d_q]permute(1, 2, 0, 3):[batch, n_q, n_head, d_q]view(batch, n_q, -1):[batch, n_q, n_head * d_v]作用:将多头的输出拼接起来"""output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)# 4.仿射变换得到最终输出output = self.fc_o(output)return attn, outputif __name__ == "__main__":n_q, n_k, n_v = 2, 4, 4d_q_, d_k_, d_v_ = 128, 128, 64batch = 5q = torch.randn(batch, n_q, d_q_)k = torch.randn(batch, n_k, d_k_)v = torch.randn(batch, n_v, d_v_)mask = torch.zeros(batch, n_q, n_k).bool()mha = MultiHeadAttention(n_head=8, d_k_=128, d_v_=64, d_k=256, d_v=128, d_o=128)attn, output = mha(q, k, v, mask=mask)print(attn.size())print(output.size())

Self-Attention

计算过程

Q, K, V全来自于原始内容X
Q = XWq,其它同理,加了个线性变化

为什么要对原始的X进行线性变化得到新QKV,而不是直接得到呢?
答:因为如果我们不对X进行变化,会导致注意力机制就只能从输入数据的固定表示中学习,限制了模型捕捉复杂依赖关系的能力。
我们加入了可以学习的Wq、Wk和Wx,这样就可以在更多的空间上去捕捉新的特征,而不是局限于输入向量的固定表示。

代码

import numpy as np
import torch
from torch import nnfrom attention.Attention import ScaledDotProductAttention
from attention.MultiHeadAttention import MultiHeadAttentionclass SelfAttention(nn.Module):def __init__(self, n_head, d_k, d_v, d_x, d_o):"""nn.Parameter:tensor的一个子类, 默认会将 requires_grad=True主要作用:表示该张量是一个可以训练的参数,会被加到self.parameters()中"""self.wq = nn.Parameter(torch.Tensor(d_x, d_k))self.wk = nn.Parameter(torch.Tensor(d_x, d_k))self.wv = nn.Parameter(torch.Tensor(d_x, d_v))self.mha = MultiHeadAttention(n_head=n_head, d_k_=d_k, d_v_=d_v, d_k=d_k, d_v=d_v, d_o=d_o)self.ha = ScaledDotProductAttention(scale=np.power(d_k, 0.5))self.init_parameters()# 初始化Wq,Wk,Wxdef init_parameters(self):for param in self.parameters():stdv = 1. / np.power(param.size(-1), 0.5)param.data.uniform_(-stdv, stdv)def forward(self, x):# 得到初始化的QKVq = torch.matmul(x, self.wq)k = torch.matmul(x, self.wk)v = torch.matmul(x, self.wv)# 进行自注意力计算,使用多头attn, output = self.mha(q, k, v)# 使用单头# attn, output = self.ha(q, k, v)return attn, outputif __name__ == "__main__":n_x = 4d_x = 80batch = 64x = torch.randn(batch, n_x, d_x)selfattn = SelfAttention(n_head=8, d_k=128, d_v=64, d_x=80, d_o=80)attn, output = selfattn(x)print(attn.size())print(output.size())

Mask 机制

主要有两种:

  • 未来掩码:屏蔽未来的信息,常见于生成式的模型
  • 填充掩码:屏蔽掉用于padding的无关元素,避免没有实际意义的score被学习到

通过Q和K计算得来的Score矩阵是Mask的关键,
未来掩码采用的是遮掉主对角线以上的元素,即(i <= j)的那部分;
填充掩码是将输入中<pad>元素的位置进行标记,得到标记矩阵,然后根据这个矩阵来遮掉一些Score;


http://www.ppmy.cn/embedded/127635.html

相关文章

如何解锁业务数据价值:基于云器Lakehouse构建面向未来的ELT现代数据栈

随着企业数据量和数据架构的日益复杂&#xff0c;企业急需寻找更快、更高效、更节省成本的数据管理和分析方法。近年来&#xff0c;现代化数据栈&#xff08;Modern Data Stack&#xff0c;简称MDS&#xff09;不断创新发展&#xff0c;受到了广泛的关注。 MDS是什么&#xff…

抖去推--短视频矩阵系统源码对外资料包

#短视频矩阵系统源码# #短视频矩阵系统源码开发# #短视频矩阵系统源码打包# 一、短视频矩阵系统源码安装 安装环境 短视频矩阵系统源码需要以下环境&#xff1a; PHP 7.0 及以上 MySQL 5.5 及以上 Nginx / Apache Redis FFMpeg 下载源码 从官网下载最新版本的短视频矩阵系统…

计算图形学-自学:几何体数据结构

1.参考文章&#xff1a;几何体数据结构学习记录 - 知乎 (zhihu.com) 2.我目前特别想要学的是四叉树和八叉树

【VUE】Vue2中如何监听(检测)对象或者数组某个属性的变化

当在项目中直接设置数组的某一项的值&#xff0c;或者直接设置对象的某个属性值&#xff0c;这个时候&#xff0c;你会发现页面并没有更新。这是因为Object.defineProperty()限制&#xff0c;监听不到变化。 解决方式&#xff1a; this.$set(你要改变的数组/对象&#xff0c;…

YOLOv11模型地址

地址链接 项目Git地址&#xff1a;https://github.com/ultralytics/ultralytics?tabreadme-ov-file

SpringBoot教程(二十四) | SpringBoot实现分布式定时任务之Quartz(基础)

SpringBoot教程&#xff08;二十四&#xff09; | SpringBoot实现分布式定时任务之Quartz&#xff08;基础&#xff09; 简介适用场景Quartz核心概念Quartz 存储方式Quartz 版本类型引入相关依赖开始集成方式一&#xff1a;内存方式(MEMORY)存储实现定时任务1. 定义任务类2. 定…

技术总结(四)

一、什么是进程和线程,他们之前区别是什么? 进程&#xff08;Process&#xff09; 是指计算机中正在运行的一个程序实例。举例&#xff1a;你打开的微信就是一个进程。线程&#xff08;Thread&#xff09; 也被称为轻量级进程&#xff0c;更加轻量。多个线程可以在同一个进程…

QT:信号与槽

QT是一种流行的C框架&#xff0c;用于开发图形用户界面&#xff08;GUI&#xff09;应用程序。在QT中&#xff0c;信号与槽是一种用于实现对象间通信的机制。 信号&#xff08;signal&#xff09;是对象发出的事件或消息&#xff0c;槽&#xff08;slot&#xff09;是接收并处…