算法手撕面经系列(1)--手撕多头注意力机制

embedded/2024/9/22 7:40:28/

多头注意力机制

 一个简单的多头注意力模块可以分解为以下几个步骤:

  1. 先不分多头,对输入张量分别做变换,得到 Q , K , V Q,K,V Q,K,V
  2. 对得到的 Q , K , V Q,K,V Q,K,V按头的个数进行split;
  3. Q , K Q,K Q,K计算向量点积
  4. 考虑是否要添因果mask
  5. 利softmax计算注意力得分矩阵atten
  6. 对注意力得分矩阵施加Dropout
  7. 将atten矩阵和 V V V矩阵相乘
  8. 再过一道最终的输出变换

代码

 给出一个 d k = d v = d m o d e l d_k=d_v=d_{model} dk=dv=dmodel的多头注意力实现如下:

python">
class MHA(nn.Module):def __init__(self,C_in,dmodel,num_head=8,p_drop=0.2):super(MHA, self).__init__()self.QW=nn.Linear(C_in,dmodel)self.KW=nn.Linear(C_in,dmodel)self.VW=nn.Linear(C_in,dmodel)self.dp=nn.Dropout(p_drop)self.W_concat=nn.Linear(dmodel,dmodel)self.n_head=num_headself.p_drop=p_dropself.depth=dmodel//num_headdef forward(self,X,casual=True):B,L,C=X.shapeQ=self.QW(X)K=self.KW(X)V=self.VW(X)Q=Q.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)K=K.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)V=V.reshape((B,L,self.n_head,-1)).permute(0,2,1,3)atten=Q.matmul(K.transpose(2,3))if casual:mask=torch.triu(torch.ones(L,L))atten=torch.where(mask==1,atten,torch.ones_like(atten)*(-2**32+1))atten=torch.softmax(atten,dim=-1)atten=self.dp(atten)out=torch.matmul(atten,V)/self.depth**(1/2)out=out.permute(0,2,1,3).reshape(B,L,-1)out=self.W_concat(out)return outif __name__=="__main__":input=torch.rand(10,5,3)model=MHA(3,64,4)res=model(input)

http://www.ppmy.cn/embedded/113475.html

相关文章

2024.9.17 Python

1.现有字典 d{‘a’:24,’g’:52,’l’:12,’k’:33}请按字 典中的 value值进行排序? sorted(d.items(),key lambda x:x[1]) [1]换成0即可变成按照键排序 2.del 列表名[index]:删除指定索引的数据 3.列表名…

Oozie

Oozie 是 Apache Hadoop 生态系统中的一个工作流调度和协调框架,用于管理和执行定时的 Hadoop 任务。它允许用户定义复杂的工作流来协调多个不同的 Hadoop 任务(如 MapReduce、Hive、Pig 等)的执行,并支持任务间的依赖关系。Oozie…

Appium环境搭建及元素定位

Appium是一个开源测试自动化框架,可用于原生,混合和移动Web应用程序测试。它使用WebDriver 协议驱动iOS,Android和Windows应用程序。 01 环境搭建步骤 Appium环境安装: 第一步安装 appium 桌面版客户端 Appium-1.12.1.dmg(MAC…

数组学习内容

动态初始化 只给长度,数据类型【】 数组名new 数据类型【数组长度】 内存图

Java项目实战II基于Java+Spring Boot+MySQL的作业管理系统设计与实现(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、论文参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者 一、前言 在教育信息化的大潮中,作业管理作为教学过程中的重要环节,其效率与效果直接影…

太阳能光伏板航拍红外图像缺陷分类数据集

太阳能光伏板航拍红外图像缺陷分类数据集 一、数据集简介 太阳能光伏板的性能直接影响到光伏发电系统的效率和可靠性。随着无人机和红外成像技术的发展,通过航拍红外图像对光伏板进行缺陷检测已成为一种高效且准确的方法。本数据集包含11种不同的缺陷分类&#xf…

k8s独立组件ingress,七层转发

一、K8S的Service 1、Service的作用 Service的作用体现在两个方面: 1、集群内部:不断跟踪pod的变化,更新endpoints中的pod对象,基于pod的IP地址不断变化的一种服务发现机制,也可以实现负载均衡,四层代理…

python内置模块pathlib.Path类操作目录和文件

python自带的pathlib模块提供了很多路径相关的功能,而pathlib.Path 是pathlib 模块中的一个核心类,它代表了文件系统中的一个路径,实现功能比如创建、删除、移动文件,读取和写入文件内容,遍历目录等。 Path 类跟os.pa…