Transformer——多头注意力机制(Pytorch)

server/2024/9/24 8:27:41/

1. 原理图

2. 代码

import torch
import torch.nn as nnclass Multi_Head_Self_Attention(nn.Module):def __init__(self, embed_size, heads):super(Multi_Head_Self_Attention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsself.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)def forward(self,queries, keys, values, mask):N = queries.shape[0]  # batch_sizequery_len = queries.shape[1]  # sequence_lengthkey_len = keys.shape[1]  # sequence_length value_len = values.shape[1]  # sequence_lengthqueries = self.queries(queries)keys = self.keys(keys)values = self.values(values)# Split the embedding into self.heads pieces# batch_size, sequence_length, embed_size(512) --> # batch_size, sequence_length, heads(8), head_dim(64)queries = queries.reshape(N, query_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)values = values.reshape(N, value_len, self.heads, self.head_dim)# batch_size, sequence_length, heads(8), head_dim(64) --> # batch_size, heads(8), sequence_length, head_dim(64)queries = queries.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# Scaled dot-product attentionscore = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** (1/2))if mask is not None:score = score.masked_fill(mask == 0, float("-inf"))# batch_size, heads(8), sequence_length, sequence_lengthattention = torch.softmax(score, dim=-1)out = torch.matmul(attention, values)# batch_size, heads(8), sequence_length, head_dim(64) --># batch_size, sequence_length, heads(8), head_dim(64) --># batch_size, sequence_length, embed_size(512)# 为了方便送入后面的网络out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)out = self.fc_out(out)return outbatch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = NoneQ = torch.randn(batch_size, sequence_length, embed_size)  
K = torch.randn(batch_size, sequence_length, embed_size)  
V = torch.randn(batch_size, sequence_length, embed_size)  model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)

 


http://www.ppmy.cn/server/62254.html

相关文章

树结构添加分组,向上向下添加同级,添加子级

树结构添加分组&#xff0c;向上向下添加同级&#xff0c;添加子级 效果代码实现页面js 效果 代码实现 页面 <el-tree :data"treeData" :props"defaultProps" :expand-on-click-node"false":filter-node-method"filterNode" :ref&…

初学SpringMVC之 JSON 篇

JSON&#xff08;JavaScript Object Notation&#xff0c;JS 对象标记&#xff09;是一种轻量级的数据交换格式 采用完全独立于编程语言的文本格式来存储和表示数据 JSON 键值对是用来保存 JavaScript 对象的一种方式 比如&#xff1a;{"name": "张三"}…

Python 列表及其常用操作详解

在Python编程中&#xff0c;列表&#xff08;List&#xff09;是一种非常常见且重要的数据结构。列表是一个有序的集合&#xff0c;可以包含任意类型的元素。列表是可变的&#xff0c;这意味着你可以在列表创建后对其进行修改&#xff0c;如添加、删除和更新元素。本文将详细介…

萝卜快跑的「悖论」

本文将探讨无人车带来的出行变革与现有交通生态之间的冲突&#xff0c;以及如何寻找技术创新与社会伦理之间的平衡点。 「做无人车的初衷&#xff0c;不是为了抢出租车网约车司机的生意&#xff0c;而是为了更好的服务老百姓&#xff0c;提供一种新的出行方式。」百度副总裁王云…

语音识别概述

语音识别概述 一.什么是语音&#xff1f; 语音是语言的声学表现形式&#xff0c;是人类自然的交流工具。 图片来源&#xff1a;https://www.shenlanxueyuan.com/course/381 二.语音识别的定义 语音识别&#xff08;Automatic Speech Recognition, ASR 或 Speech to Text, ST…

MySQL与Redis优化

MySQL优化策略&#xff1a; 查询优化&#xff1a;使用EXPLAIN分析查询语句&#xff0c;优化JOIN操作&#xff0c;减少子查询和复杂的WHERE条件。索引优化&#xff1a;合理创建索引以加快查询速度&#xff0c;同时避免过度索引导致写性能下降。数据类型优化&#xff1a;使用合适…

使用Apache服务部署静态网站

前言&#xff1a;本博客仅作记录学习使用&#xff0c;部分图片出自网络&#xff0c;如有侵犯您的权益&#xff0c;请联系删除 目录 一、网站服务程序 ​二、配置服务文件参数 ​三、SELinux安全子系统 四、个人用户主页功能 ​五、虚拟网站主机功能 六、Apache的访问控制…

【linux】【深度学习】fairseq框架安装踩坑

直接pip install fairseq发现跑代码时候老是容易崩&#xff0c;所以选择用源码编译安装。 python环境选择3.8以上都行&#xff0c;我选择3.10 首先安装torch&#xff0c; 我选择安装pip install torch1.13.1 torchaudio0.13.1以及cuda 11.7 &#xff08;具体cuda根据个人显卡进…