动手实现Multi-Head Attention

news/2024/9/23 9:23:42/

文章目录

  • 重点讲解
  • 代码实现
    • 流程图(维度变换示意图)
    • self-attention示例
    • 加入mask示例

MultiHeadAttention 是 Transformer 模型中的一个核心组件,它允许模型在处理序列的每个位置时同时考虑来自多个“视角”(即头部)的信息。这样做可以提高模型对不同位置关系的理解能力。

重点讲解

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

主要步骤:

  • 线性变换得到QKV,并将QKV分割为多头
  • 计算缩放点积注意力(注意mask可选)
  • 拼接多头
  • 最后再进行一次线性变换

代码实现

下面,我将使用 PyTorch 框架实现一个基本的 MultiHeadAttention 模块。

python">import torch
import torch.nn as nn
import torch.nn.functional as Fimport mathclass MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0, "d_model must be divisible by num_heads"self.d_model = d_modelself.num_heads = num_headsself.depth = d_model // num_heads# 定义线性层和输出线性层self.query_linear = nn.Linear(d_model, d_model)self.key_linear = nn.Linear(d_model, d_model)self.value_linear = nn.Linear(d_model, d_model)self.final_linear = nn.Linear(d_model, d_model)def split_heads(self, x, batch_size):"""分割最后一个维度到 (num_heads, depth).转置结果使得形状为 (batch_size, num_heads, seq_length, depth)"""x = x.view(batch_size, -1, self.num_heads, self.depth)return x.permute(0, 2, 1, 3)def forward(self, query, key, value, mask=None):batch_size = query.size(0)# 1. 线性层和分割到多头query = self.split_heads(self.query_linear(query), batch_size)key = self.split_heads(self.key_linear(key), batch_size)value = self.split_heads(self.value_linear(value), batch_size)# 2. 缩放点积注意力scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.depth)if mask is not None:scores = scores.masked_fill(mask == True, float('-inf'))attention_weights = F.softmax(scores, dim=-1)# 3. 将注意力权重应用到值上output = torch.matmul(attention_weights, value)# 4. 连接头部output = output.permute(0, 2, 1, 3).contiguous()output = output.view(batch_size, -1, self.d_model)# 5. 最后一次线性变换output = self.final_linear(output)return output

流程图(维度变换示意图)

在这里插入图片描述

self-attention示例

python">   d_model = 512  # 模型维度num_heads = 8  # 头数mha = MultiHeadAttention(d_model, num_heads)# 创建随机数据batch_size = 4seq_length = 60x = torch.rand(batch_size, seq_length, d_model)  # 输入假设维度为 (batch_size, seq_length, d_model)output = mha(x, x, x)  # 自注意力机制,qkv的输入相同;而cross-attention中,query来自decoder,kv来自encoderprint(output.shape)  

加入mask示例

解码器的自注意力层需要确保当前位置只能注意到前面的位置(包括当前位置),而不是未来的位置。这通常通过一个未来位置掩码实现,它是一个下三角矩阵

python">import torchdef generate_square_subsequent_mask(seq_len):"""生成一个未来步骤掩码,用于解码器中防止看到未来信息。"""mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()return maskd_model = 512  # 模型维度
num_heads = 8  # 头数
mha = MultiHeadAttention(d_model, num_heads)# 创建随机数据
batch_size = 4
seq_length = 60
x = torch.rand(batch_size, seq_length, d_model)  # 输入假设维度为 (batch_size, seq_length, d_model)# 生成掩码并将其应用于解码器的自注意力层
future_mask = generate_square_subsequent_mask(seq_length).to(x.device)
output = mha(x, x, x, mask=future_mask)  # 自注意力机制
print(output.shape)  # 应为 (batch_size, seq_length, d_model)

注意广播机制
在 PyTorch 中,masked_fill 函数可以很灵活地处理维度差异情况,通过广播(broadcasting)机制来匹配维度。

在这里插入图片描述


http://www.ppmy.cn/news/1425643.html

相关文章

压缩机回油控制逻辑

压缩机回油控制 1)在压缩机的运行过程中,当实际转速小于【回油最小频率】时,则开始回油时间计时; 2)当回油计数时间累计达到设定的【回油周期】或该压缩机在 1 小时内启停达到【回油启停次数】, 且当压缩机…

探索Web3的奇迹:数字时代的新前景

在数字化时代的潮流中,我们不可避免地迎来了一个全新的篇章——Web3时代的到来。在这个时代中,区块链技术作为数字化世界的核心,正在重塑着我们的生活方式、经济模式以及社会结构。在Web3时代,我们将目睹着一个以去中心化、透明化…

Day91:API攻防-接口安全SOAPOpenAPIRESTful分类特征导入项目联动检测

目录 API分类特征-SOAP&OpenAPI&RESTful API分类特征 API常见漏洞 API检测流程 API检测项目-Postman&APIKit&XRAY 工具自动化-SOAP - WSDL Postman 联动burpxray APIKit插件(可联动xray) 工具自动化-OpenApi - Swagger Postman 联动burpxray APIKit…

MySQL—MySQL架构

MySQL—MySQL架构 MySQL逻辑架构图如下: Connectors连接器:负责跟客户端建立连接;Management Serveices & Utilities系统管理和控制工具;Connection Pool连接池:管理用户连接,监听并接收连接的请求,转发所有连接的…

服务器清理挖矿问题

top -c ps -ef netstat -antp # 查所有端口链接 ls -al /proc/$PID/exe # 查执行文件 kill -9 $PID # 杀进程 // 查文件 /usr/lib/systemd/system /usr/lib/systemd/system/multi-user.target.wants /etc/rc.local /etc/inittab /etc/rc0.d/ /etc/rc1.d/ /etc/rc2.d/…

会议室预约小程序开源版开发

会议室预约小程序开源版开发 支持设置免费预约和付费预约、积分兑换商城、积分签到等 会议室类目,提供多种类型和设施的会议室选择,满足不同会议需求。 预约日历,展示会议室预约情况,方便用户选择空闲时段。 预约记录&#xff0…

深入剖析跨境电商平台风控机制,探索测评安全与稳定的秘诀

在跨境电商测评市场鱼龙混杂的当下,测评过程中可能隐藏的陷阱保持高度警觉。多年的测评经验告诉我们,选择一个适合的测评系统对于项目的成功至关重要。近年来,测评技术如雨后春笋般涌现,市场上涌现出众多测评系统,覆盖…

MacOS Python版本管理(pyenv)

1. 通过 homebrew 安装 pyenv brew update brew install pyenv 2. 修改 zsh profile 否则通过pyenv切换python版本会不生效 # 编辑 .zshrc or ~/.bash_profile vim ~/.zshrc# 在配置下面增加 export PYENV_ROOT"$HOME/.pyenv" export PATH"$PYENV_ROOT/shi…