Transformer基础 多头自注意力机制

devtools/2025/2/12 2:10:29/

# 1. **自注意力机制**:Transformer通过自注意力机制能够高效地计算序列内所有元素之间的关系,这使得模型能够捕捉到长距离依赖,无论这些依赖的距离有多远。

# 2. **并行化处理**:与RNN不同,Transformer可以同时处理整个序列,这极大地提高了训练效率。

# 3. **无需递归结构**:Transformer完全摒弃了递归结构,这意味着不存在梯度消失或爆炸的问题,同时也使得模型能够更容易地学习长距离的依赖关系

# tf 的三个关键 多头 掩码 位置编码

# tf 的输入部分 位置编码和词嵌入

# 为什么要有位置编码  因为tf是并行运行的并不知道谁先谁后   rnn是必定有先后的

# **位置编码**:由于Transformer本身不具备处理序列顺序的能力,通过添加位置编码到输入序列,模型能够利用序列中元素的位置信息。

# 位置编码 实际上就是一个向量 把它跟词向量结合 放在词向量中

# ##### 计算公式

# 位置编码可以有多种实现方式,Transformer原始论文中提出的位置编码是通过正弦和余弦函数来计算的,

# **这样做的好处是能够让模型学习到相对位置信息**,因为这些函数对位置的偏移是可预测的。对于序列中的每个位置pos,和每个维度 i ,位置编码 ( pos , i ) 是这样计算的

# 根据三角函数的公式 知道1,2的位置即3 的位置也知道了

# 自注意的原理

# 本单词向其他单词发送 询问Q  Q=权重矩阵q X 本词的词向量q  其他单词回应  K K=权重矩阵k X 其他词的词向量k 得到阿尔法a=q点乘k 点乘的原理是查看两个向量的相似度

# 自注意的代码实现

import torch

import torch.nn as nn

from torch.nn import functional as F

import numpy as np

def self_att():

    x = torch.randn(3,2)

    W_q=torch.eye(2)

    W_k=torch.eye(2)

    W_v=torch.eye(2)

    Q=torch.matmul(x,W_q)

    K=torch.matmul(x,W_k)

    V=torch.matmul(x,W_v)

    d_k=2

    score=torch.matmul(Q,K.T)/np.sqrt(torch.tensor(d_k))

    score.masked_fill(0,-1e10)

    att_weight=F.softmax(score,dim=-1)

    res=torch.matmul(att_weight,V)

    return res

# 多头自注意力的概念

# 就是把 QKV 分别拆成多个q1 q2 k1 k2 v1 v2 然后 每个位置的对应走一遍 流程 最后的结果是散的 拼起来然后乘一个W 组合为一个结果

# tf一般分8个头

# 为什么要分成多个

# 原来一个词的位置只能一个位置表示 现在分8个有8个位置了 更方便找到合适的位置

# 掩码的作用 把矩阵中的0变为一个特别小的值 然后经过softmax后成为真的0 不然原本为0 softmax后为一个小的数

class MutiSelfAtt(nn.Module):

    def __init__(self,d_model,num_head, *args, **kwargs):

        super().__init__(*args, **kwargs)

        self.num_head=num_head

        self.d_model=d_model

        self.dim=d_model//num_head

        # 定义三个线性层 wx

        self.q_linear=nn.Linear(d_model,d_model)

        self.k_linear=nn.Linear(d_model,d_model)

        self.v_linear=nn.Linear(d_model,d_model)

        self.linear=nn.Linear(d_model,d_model)

    def forward(self,x): # x是batch_size,seq_len,emb_dim

        Q=self.q_linear(x)

        K=self.k_linear(x)

        V=self.v_linear(x)

        batch_size=x.shape[0]

        print("------------",Q.shape)

        Q=Q.view(batch_size,-1,self.num_head,self.dim) # 拆为 batch_size,seq_len,head,dim

        K=K.view(batch_size,-1,self.num_head,self.dim)

        V=V.view(batch_size,-1,self.num_head,self.dim)

        # 变为batch_size,head,seq_len,dim 对应头的每个对应位置相乘

        Q=Q.transpose(1,2)

        K=K.transpose(1,2)

        V=V.transpose(1,2)

        res=self.self_att(Q,K,V)

        res=res.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)

        res=self.linear(res)

        return res

    def self_att(self,Q,K,V,mask=None):

        d_k=self.dim

        score=torch.matmul(Q,K.transpose(-1, -2))/np.sqrt(torch.tensor(d_k))

        if mask is not None:

            score=score.masked_fill(mask==0,-1e10)

        att_weight=F.softmax(score,dim=-1)

        res=torch.matmul(att_weight,V)

        return res

data=torch.randn(3,2,64)

msa=MutiSelfAtt(64,4)

res=msa(data)

print(res.shape)

# 残差连接(Add)

# 层归一化(Norm)跟批量归一化不同

# 前馈神经网络子层 Feed ForwardFeed Forward FFN

# 全连接层

class FFN(nn.Module):

    def __init__(self,d_model,d_ff=256 ,*args, **kwargs):

        super().__init__(*args, **kwargs)

        self.d_model=d_model

        self.ffn=nn.Sequential(

            nn.Linear(d_model,d_ff),

            nn.ReLU(),

            nn.Linear(d_ff,d_model)

        )

        self.norm_res=nn.LayerNorm(self.d_model)

    def forward(self,x):

        res=x

        output=self.ffn(x)

        # 残差和层归一化

        output=self.norm_res(output)

        output=res+output

        return output


http://www.ppmy.cn/devtools/158081.html

相关文章

如何使用Socket编程在Python中实现实时聊天应用

在现代的网络应用中,实时聊天功能成为了不可或缺的一部分。从社交平台到在线客服系统,实时聊天应用广泛存在。Python提供了强大的socket库,可以帮助我们轻松实现基于TCP协议的实时聊天功能。本文将介绍如何通过Socket编程在Python中实现一个简…

Git―分支管理

Git ⛅创建&切换&合并分支⛅删除分支⛅合并冲突⛅合并模式⛅Bug 分支⛅强制删除分支 master → 主分支 # 查看本地所有分支 git branch分支前面的*, 代表当前所在的分支 图中当前所在的分支为master ⛅创建&切换&合并分支 # 创建分支 git branch "bra…

Spring框架学习大纲

Spring框架学习大纲 一、Spring基础入门 Spring概述 Spring框架发展历史与核心优势Spring核心模块组成(IoC、AOP、Data Access、Web MVC等)Spring与传统Java EE开发对比 控制反转(IoC)与依赖注入(DI) IoC…

贪心算法_翻硬币

蓝桥账户中心 依次遍历 不符合条件就反转 题目要干嘛 你就干嘛 #include <bits/stdc.h>#define endl \n using namespace std;int main() {ios::sync_with_stdio(0); cin.tie(0); cout.tie(0); string s; cin >> s;string t; cin >> t;int ret 0;for ( i…

Android双屏异显Presentation接口使用说明

在点餐、收银、KTV等场景,对于双屏异显的需求是非常多的,首先可以节省硬件成本。而现在的智能板卡很多运行Android系统,从Android4.2开始支持WiFi Display(Miracast)功能后,就开始支持双屏异显Presentation这套应用层接口了,下面以Android5.1系统来说明这套接口的使用要…

安卓开发用Java、Flutter、Kotlin的区别

在安卓开发中&#xff0c;Java、Kotlin 和 Flutter 是三种常见的技术选择&#xff0c;各有优缺点。以下是它们的区别&#xff1a; 1. Java 历史&#xff1a;Java 是安卓开发的传统语言&#xff0c;自安卓平台推出以来一直作为主要开发语言。成熟度&#xff1a;拥有丰富的库和…

基于 Linux 与 CloudFlare 的智能实时 CC/DDoS 防御方案

随着互联网的快速发展,网络安全问题日益严峻,尤其是 CC(Challenge Collapsar)攻击 和 DDoS(分布式拒绝服务)攻击 对网站和服务的威胁越来越大。为了应对这些攻击,许多企业和开发者选择使用 CloudFlare 作为防御工具。CloudFlare 提供了强大的 WAF(Web Application Fire…

C++设计模式 - 模板模式

一&#xff1a;概述 模板方法&#xff08;Template Method&#xff09;是一种行为型设计模式。它定义了一个算法的基本框架&#xff0c;并且可能是《设计模式&#xff1a;可复用面向对象软件的基础》一书中最常用的设计模式之一。 模板方法的核心思想很容易理解。我们需要定义一…