【深度学习中的注意力机制1】11种主流注意力机制112个创新研究paper+代码——缩放点积注意力(Scaled Dot-Product Attention)

embedded/2025/2/21 8:16:59/

深度学习中的注意力机制1】11种主流注意力机制112个创新研究paper+代码——缩放点积注意力(Scaled Dot-Product Attention)

深度学习中的注意力机制1】11种主流注意力机制112个创新研究paper+代码——缩放点积注意力(Scaled Dot-Product Attention)


文章目录

  • 深度学习中的注意力机制1】11种主流注意力机制112个创新研究paper+代码——缩放点积注意力(Scaled Dot-Product Attention)
  • 前言
  • 1. 起源与提出
  • 2. 原理
  • 3. 发展
  • 4. 代码实现
  • 5. 总结


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

前言

“缩放点积注意力”(Scaled Dot-Product Attention)是深度学习中注意力机制的重要组成部分。该机制最初由Vaswani et al.在2017年提出的Transformer模型中引入,并在自然语言处理(NLP)和计算机视觉等领域取得了巨大成功。

1. 起源与提出

注意力机制最早应用于机器翻译和序列生成任务中,旨在通过关注输入序列中最相关的部分来生成输出序列。而“缩放点积注意力”是Transformer模型中提出的一种高效的注意力机制,用来替代传统的RNN或LSTM在处理序列任务时的长依赖性问题。

缩放点积注意力的核心是计算输入序列之间的相似度,并将相似度作为权重来生成加权的输出序列。这种方法大大提升了模型并行计算的效率,并且适用于处理大规模数据。

2. 原理

缩放点积注意力的核心计算如下:

  • 给定三个输入:Query(Q)、Key(K)和Value(V),我们计算Query和Key的点积,以衡量每个Query与Key之间的相似性。
  • 将点积的结果除以一个缩放因子(通常是 d k \sqrt{d_k} dk ,其中 d k d_k dk是Key的维度),以避免点积值过大导致梯度消失问题。
  • 然后对这些相似性分数通过softmax进行归一化,得到注意力权重。
  • 最后,使用这些权重对Value进行加权求和,生成最终的输出。

数学公式如下:

在这里插入图片描述

  • 这里的Q,K,V分别代表查询、键和值向量矩阵。
  • d k \sqrt{d_k} dk 是缩放因子,目的是防止点积结果过大。

3. 发展

缩放点积注意力的提出,使得Transformer模型摆脱了对RNN和LSTM等循环神经网络的依赖,不仅提高了模型的并行能力,还显著提升了处理长距离依赖的能力。在Transformer的基础上,缩放点积注意力被广泛应用于语言模型(如BERT、GPT系列)、图像处理(如ViT)、以及多模态任务中。

此外,缩放点积注意力的思想还被发展出多头注意力机制(Multi-Head Attention),通过多个不同的注意力头来增强模型的表达能力和信息捕捉能力

4. 代码实现

下面是缩放点积注意力的Python实现,基于PyTorch框架:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ScaledDotProductAttention(nn.Module):def __init__(self):super(ScaledDotProductAttention, self).__init__()def forward(self, Q, K, V, mask=None):# 计算Q和K的点积, 得到注意力分数scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(K.size(-1), dtype=torch.float32))# 如果有mask,设置为非常小的负数来忽略该位置if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)# 对分数进行softmax归一化,得到注意力权重attention_weights = F.softmax(scores, dim=-1)# 使用注意力权重对V进行加权求和output = torch.matmul(attention_weights, V)return output, attention_weights# 示例输入
batch_size = 2
seq_len = 5
d_k = 4# 随机初始化Query、Key、Value
Q = torch.rand(batch_size, seq_len, d_k)
K = torch.rand(batch_size, seq_len, d_k)
V = torch.rand(batch_size, seq_len, d_k)# 实例化缩放点积注意力
attention = ScaledDotProductAttention()
output, attention_weights = attention(Q, K, V)print("输出:", output)
print("注意力权重:", attention_weights)

代码逐句解释

import torch 和 import torch.nn as nn:

  • 导入PyTorch库,用于张量操作和神经网络构建。

class ScaledDotProductAttention(nn.Module):

  • 定义一个缩放点积注意力的类,继承自torch.nn.Module

def forward(self, Q, K, V, mask=None):

  • 定义前向传播函数,接收Query、Key、Value和可选的mask作为输入。

scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(K.size(-1), dtype=torch.float32)):

  • 首先,计算Query和Key的点积。
  • 然后将点积结果除以 d k \sqrt{d_k} dk ,这里K.size(-1)表示Key的维度。

if mask is not None: scores = scores.masked_fill(mask == 0, -1e9):

  • 如果提供了mask,将需要屏蔽的部分的注意力分数设置为非常小的值(-1e9),确保这些位置的权重接近于0。

attention_weights = F.softmax(scores, dim=-1):

  • 对注意力分数进行softmax归一化,生成权重,使得所有权重的和为1。

output = torch.matmul(attention_weights, V):

使用归一化的注意力权重对Value矩阵进行加权求和,生成最终的输出。
return output, attention_weights:

  • 返回最终的加权输出和注意力权重,供后续使用或分析。

初始化与使用部分:

  • Q, K, V 是随机生成的输入张量,模拟实际的Query、Key和Value矩阵。
  • 通过调用attention(Q, K, V)来计算最终的输出和注意力权重。

5. 总结

“缩放点积注意力”极大提高了Transformer模型的性能,使其在许多任务上取得了突破性成果。这个机制本质上是一种加权求和操作,通过Query和Key的相似性确定每个Value的重要性。其背后的原理虽然简单,但其高效性和扩展性使得它在多个领域得到了广泛应用。


http://www.ppmy.cn/embedded/129115.html

相关文章

Vue项目兼容IE11

配置Vue项目兼容IE11详解 Vue 不支持 IE8 及以下版本,因为 Vue 使用了 IE8 无法模拟的 ECMAScript 5 特性。但对于 IE9,Vue 底层是支持。 由于开发过程中,我们经常会使用一些第三方插件或组件,对于这些组件,有时我们…

牛客周赛63(C++实现)

🌈个人主页:Yui_ 🌈Linux专栏:Linux 🌈C语言笔记专栏:C语言笔记 🌈数据结构专栏:数据结构 🌈C专栏:C 文章目录 1.小红的好数1.1 题目描述1.2 思路1.3 代码 2.…

elementUI,设置日期,只能选择过去的和今天的日期

在 el-date-picker 组件中加&#xff1a;:picker-options"pickerOptions" <el-form-item label"票据生成日期&#xff1a;"> <el-date-picker v-model"date1" type"daterange" range-separator"至" value-format&…

Linux网络编程(七)-TCP协议客户端及代码实现

1.TCP的客户端代码流程简述 这一章将为大家讲解Socket通信中客户端的实现过程&#xff0c;还是先上图&#xff0c;请大家了解客户端的步骤 可以看到&#xff0c;相比服务端&#xff0c;客户端的步骤简单的很多。事实上这种情况比较多&#xff0c;比如一个服务端会有多个客户端…

2024年网络安全(黑客技术)三个月自学手册

&#x1f91f; 基于入门网络安全/黑客打造的&#xff1a;&#x1f449;黑客&网络安全入门&进阶学习资源包 前言 什么是网络安全 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、…

【RestTemplate】重试机制详解

在现代的微服务架构中&#xff0c;服务间的网络调用是常见的场景。在这种情况下&#xff0c;网络请求可能会因为多种原因失败&#xff0c;比如超时、服务不可用等。为了提升系统的鲁棒性&#xff0c;我们可以为 RestTemplate 配置重试机制。本文将详细探讨如何为 RestTemplate …

Tomcat日志文件详解及catalina.out日志清理方法

目录 前言1. Tomcat日志文件详解1.1 catalina.out1.2 localhost_access_log1.3 catalina.<date>.log1.4 host-manager.<date>.log 和 manager.<date>.log1.5 localhost.<date>.log 2. catalina.out文件管理与清理方法2.1 为什么不能直接删除catalina.o…

Java 二分查找算法详解及通用实现模板案例示范

1. 引言 二分查找&#xff08;Binary Search&#xff09;是一种常见的搜索算法&#xff0c;专门用于在有序数组或列表中查找元素的位置。它通过每次将搜索空间缩小一半&#xff0c;从而极大地提高了查找效率。相比于线性查找算法&#xff0c;二分查找的时间复杂度为 O(log n)&…