深度学习注意力机制类型总结pytorch实现代码

news/2024/11/12 20:14:26/

一、注意力机制的基本原理

深度学习中,注意力机制(Attention Mechanism)已经成为一种重要的技术。意力机制通过动态调整模型的注意力权重,来突出重要信息,忽略不重要的信息,大大提高了模型的效果

注意力机制的基本思想是:在处理输入序列时,模型可以根据当前的上下文动态地选择关注哪些部分。具体来说,注意力机制通过计算查询向量(Query)、键向量(Key)之间的相似度来确定注意力权重,然后对值向量(Value)进行加权求和,得到最终的输出。(当K=V时,就是普通的注意力机制)

二、注意力机制的计算步骤

  • 计算注意力得分:对于每个查询向量Q和键向量K,计算它们之间的注意力得分
  • 计算注意力权重:使用Softmax函数将注意力得分转换为注意力权重,使得权重和为1
  • 加权求和值向量:对值向量V进行加权求和,得到最终的输出

三、常见的注意力机制类型

不同的注意力机制差别主要在于注意力得分的计算方式

加性模型

  • 定义:加性注意力通过将查询向量和键向量进行拼接后,经过一个前馈神经网络来计算注意力权重
  • 公式Attention(Q,K,V) = softmax(e)V;e_{i,j} = v^T tanh(W_qQ_i + W_kK_j),其中v、W_q、W_k是可学习的参数

点积模型

  • 定义:点积注意力通过计算查询向量和键向量的点积来确定注意力权重,,在实现上可以更好地利用矩阵乘积,从而计算效率更高
  • 公式Attention(Q,K,V)=softmax(QK^T)V

缩放点积模型

  • 定义:当输入向量的维度较高时,点积的值可能会较大,从而导致softmax的梯度会比较小,缩放点积可以较好地解决这个问题
  • 公式:Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d}})V,其中d是键向量K的维度

双线性模型

  • 定义:双线性模型是一种泛化的点积模型,可以看成是分别对Q/K进行线性变换后,计算点积
  • 公式:Attention(Q,K,V)=softmax(QWK^T)V

多头注意力

  • 定义:多头注意力通过并行计算多个不同的自注意力,然后将结果拼接起来,每个注意力关注输入的不同部分,使得信息更加丰富
  • 公式MultiHead(Q,K,V) = concat(head_1,...,head_n),其中head_i = Attention(Q_i,K_i,V_i)

四、多头自注意力机制的pytorch手动实现

自注意力机制的基本思想是,在处理序列数据时,每个元素都可以与序列中的其他元素建立关联,而不仅仅是依赖于相邻位置的元素。它通过计算元素之间的相对重要性来自适应地捕捉元素之间的长程依赖关系

具体而言,对于序列中的每个元素,自注意力机制计算其与其他元素之间的相似度,并将这些相似度归一化为注意力权重。然后,通过将每个元素与对应的注意力权重进行加权求和,可以得到自注意力机制的输出。下面是多头自注意力机制的pytorch手动实现代码:

import torch
import torch.nn as nnclass MultiheadAttention(nn.Module):def __init__(self, hidden_size, head_num, dropout):super(MultiheadAttention, self).__init__()# hidden_size:每个词输出的向量维度# head_num:多头注意力的数量self.hidden_size = hidden_sizeself.head_num = head_num# 强制 hidden_size 必须整除 head_numassert hidden_size % head_num == 0# 定义 W_q, W_k, W_v矩阵self.w_q = nn.Linear(hidden_size, hidden_size)self.w_k = nn.Linear(hidden_size, hidden_size)self.w_v = nn.Linear(hidden_size, hidden_size)self.fc = nn.Linear(hidden_size, hidden_size)self.drop = nn.Dropout(dropout)# 缩放self.scale = torch.sqrt(torch.FloatTensor([hidden_size // head_num]))def forward(self, query, key, value, mask=None):batch_size = query.shape[0]# Q: [64,12,300], batch_size 为 64,有 12 个词,每个词的 Query 向量是 300 维# K/V: [64,10,300], batch_size 为 64,有 10 个词,每个词的 Query 向量是 300 维Q = self.w_q(query)K = self.w_k(key)V = self.w_v(value)# Q: [64,12,300] 拆分多组注意力 -> [64,12,6,50] 转置得到 -> [64,6,12,50]# K/V: [64,10,300] 拆分多组注意力 -> [64,10,6,50] 转置得到 -> [64,6,10,50]# 转置是为了把注意力的数量 6 放到前面,把 10 和 50 放到后面,方便下面计算Q = Q.view(batch_size, -1, self.head_num, self.hidden_size // self.head_num).permute(0, 2, 1, 3)K = K.view(batch_size, -1, self.head_num, self.hidden_size // self.head_num).permute(0, 2, 1, 3)V = V.view(batch_size, -1, self.head_num, self.hidden_size // self.head_num).permute(0, 2, 1, 3)# 第 1 步:Q 乘以 K的转置,除以scale# [64,6,12,50] * [64,6,50,10] = [64,6,12,10]attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale# 把 mask 不为空,那么就把 mask 为 0 的位置的 attention 分数设置为 -1e10if mask is not None:attention = attention.masked_fill(mask == 0, -1e10)# 第 2 步:计算上一步结果的 softmax,再经过 dropout,得到 attention。# 注意,这里是对最后一维做 softmax,也就是在输入序列的维度做 softmax,attention: [64,6,12,10]attention = self.drop(torch.softmax(attention, dim=-1))# 第 3 步,attention结果与V相乘,得到多头注意力的结果# [64,6,12,10] * [64,6,10,50] = [64,6,12,50]x = torch.matmul(attention, V)# 因为 query 有 12 个词,所以把 12 放到前面,把 5 和 60 放到后面,方便下面拼接多组的结果# x: [64,6,12,50] 转置-> [64,12,6,50]x = x.permute(0, 2, 1, 3).contiguous()# 这里的矩阵转换就是:把多组注意力的结果拼接起来# x: [64,12,6,50] -> [64,12,300]x = x.view(batch_size, -1, self.hidden_size)x = self.fc(x)return x# 测试手动实现的 MultiheadAttention
if __name__ == "__main__":# batch_size 为 64,有 12 个词,每个词的 Query 向量是 300 维query = torch.rand(64, 12, 300)# batch_size 为 64,有 12 个词,每个词的 Key 向量是 300 维key = torch.rand(64, 10, 300)value = torch.rand(64, 10, 300)attention = MultiheadAttention(hidden_size=300, head_num=6, dropout=0.1)output = attention(query, key, value)# output: torch.Size([64, 12, 300])print(output.shape)


http://www.ppmy.cn/news/1545842.html

相关文章

linux服务器通过手机USB共享网络

一、背景: 现网交付时,客户机房设备未接入互联网,需要联网拉去软件包时,通过手机USB共享网络给服务器。 二、测试环境: ubuntu、centos 三、操作步骤: 1、 确认networkmanager服务开启(cento…

命令行工具PowerShell使用体验

命令行工具PowerShell使用 PowerShell是微软开发的一种面向对象的命令行Shell和脚本语言环境,它允许用户通过命令行的方式管理操作系统。相较于传统CMD,PowerShell增加了面向对象的程序设计框架,拥有更强大的功能和扩展性。使用PowerShell可…

若Git子模块的远端地址发生了变化本地应该怎么调整

文章目录 前言git submodule 相关命令解决方案怎么保存子模块的版本呢总结 前言 这个问题复杂在既有Git又有子模块,本身Git的门槛就稍微高一点,再加上子模块的运用,一旦出现这种远端地址发生修改的情况会让人有些懵,不知道怎么处…

anaconda 安装笔记Ubuntu20

愿我们终有重逢之时,而你还记得我们曾经讨论的话题。 group 868373192 second group 277356808 在 Ubuntu 20.04 上安装 Anaconda 的特定版本(例如 4.2)可以通过以下步骤完成。请注意,Anaconda 4.2 是一个较旧的版本,…

07 Oracle数据库恢复基础解析:从检查点到归档,一步步构建数据安全防线

文章目录 Oracle数据库恢复基础解析:从检查点到归档,一步步构建数据安全防线一、检查点(Checkpoint)1.1 检查点定义1.2 检查点重要性1.3 检查点工作原理1.4 手动触发检查点 二、日志(Redo Log)2.1 日志定义…

TVM计算图分割--LayerGroup

文章目录 介绍Layergroup调研TVM中的LayergroupTVM Layergroup进一步优化MergeCompilerRegions处理菱形结构TVM中基于Pattern得到的子图TPUMLIR地平线的Layergroup介绍 Layergroup目前没找到严格、明确的定义,因为不同厂家的框架考虑的因素不同,但是基本逻辑是差不多的。一般…

单元测试日志打印相关接口及类 Logger

LoggerFactory 简介 单元测试常用日志打印工具LoggerFactory。 LoggerFactory 代码结构 LoggerFactory 是 JUnit 平台中的一个类,用于创建 Logger 实例。它被设计用于提供日志记录功能,使得 JUnit 在执行测试时能够记录信息、警告、错误等。 LoggerFact…

关于 AJAX 与 Promise

AJAX (Asynchronous JavaScript and XML) AJAX(Asynchronous JavaScript and XML)是一种在网页上异步传输数据的技术,允许网页在不重新加载整个页面的情况下更新部分内容。这提高了用户的体验,因为用户不需要等待整个页面重新加载…