【模型】Self-Attention

devtools/2024/10/15 19:34:20/

Self-Attention 机制(Self-Attention Mechanism)是近年来深度学习中,特别是自然语言处理(NLP)任务中广泛应用的一种机制,它最初出现在Transformer架构中,用于捕捉序列数据中的依赖关系。该机制的核心思想是:通过给定序列的所有元素之间分配不同的权重,模型能够灵活地关注序列中的相关部分,从而更好地理解上下文关系。

一、背景

传统的序列处理模型,如RNN(循环神经网络)或LSTM(长短期记忆网络),通常处理序列的方式是依赖于时间步(timestep)上的顺序信息。这意味着模型在捕捉长距离依赖(long-range dependencies)时,效率较低。而Self-Attention则通过全局的方式来捕捉序列中任意两个元素之间的关系,不需要按照时间步逐一处理。

二、Self-Attention 的工作原理

Self-Attention 通过计算 Query、Key、Value,让序列中的每个元素与序列的其他元素进行比较,计算它们的相似性。
然后,Self-Attention 根据这个相似性来更新每个元素的表示,使得每个元素的表示不再只是它自己的信息,而是结合了与序列中其他元素的关联信息。

假设输入序列为 {x1​,x2​,…,xn​},每个元素 xi 是一个向量。

具体步骤如下:

  1. 计算 Query、Key 和 Value

    • 对每个输入 xi,我们生成三个向量:Query (Q),Key (K),和 Value (V)。
      Query (Q) 表示当前元素“想要询问什么”,Key (K) 表示当前元素“可以被询问什么”,Value (V) 承载实际的输入信息。
      这些向量由输入向量通过三组不同的可学习权重矩阵线性变换得到:
      在这里插入图片描述

    其中,Wq、Wk、Wv 是可学习的权重矩阵。

  2. 计算注意力得分

    • 对第 i 个元素(第 i 个 Query 向量 Qi)来说,要与所有的 Key 向量进行相似性计算,即要计算 Qi与所有 Kj​(j=1,2,…,n)之间的相似性得分。通过点积来实现:
      在这里插入图片描述

    这里,dk​ 是 Key 向量的维度。缩放因子  是为了稳定训练过程,防止值过大。

    这些得分反映了输入序列中不同元素之间的相关性。

  3. Softmax 归一化

    • 计算完所有的相似度后,通过 Softmax 函数对这些得分进行归一化,得到权重分布:
      在这里插入图片描述

      这一步可以确保所有权重的和为 1。

  4. 加权求和得到输出

    • 最后,使用这些归一化后的权重 αij​ 对 Value 向量 Vj 进行加权求和,得到更新后的表示:

      在这里插入图片描述

这样,每个输入 xi 的输出不仅仅依赖于它自身,还依赖于与其他所有输入的关系权重。这使得模型能够更好地捕捉输入序列中远程词语之间的关系。

三、多头注意力机制

为了让模型在不同的子空间上捕捉不同的注意力模式,Multi-Head Attention(多头注意力)机制进一步扩展了 Self-Attention。具体来说,多头注意力会将输入的 Query、Key 和 Value 向量分成多个头(head),在每个头上分别执行注意力操作,最后将各头的输出进行拼接并投影回原空间。

这种方法使模型可以从多个角度理解输入序列中的不同部分,从而增强模型的表达能力。

四、Self-Attention 的优势

  • 并行化计算:不同于RNN等需要逐步处理输入序列,Self-Attention 允许并行计算,使得训练速度大大提升。
  • 捕捉长距离依赖关系:由于每个输入都可以关注整个序列的所有其他元素,Self-Attention 机制能够很好地捕捉长距离的依赖关系。
  • 可解释性:Attention 权重可以直观地解释模型关注了哪些输入,这为模型的决策过程提供了可解释性。

五、应用

Self-Attention 机制是 Transformer 模型的核心,并且广泛应用于多种任务中,包括:

  • 机器翻译:如 Google 的 Transformer 模型被广泛应用于机器翻译任务中。
  • 文本生成:如 GPT 系列模型中的自回归 Transformer 结构。
  • 文本分类和问答系统:BERT 等双向 Transformer 结构也基于 Self-Attention,广泛用于各种 NLP 任务。

总的来说,Self-Attention 机制极大地推动了深度学习领域,特别是自然语言处理技术的发展。

实现一个简化的 Self-Attention Transformer 机器翻译模型,它没有包括多头注意力(Multi-Head Attention)、位置编码(Positional Encoding)等功能:

import torch
import torch.nn as nn
import torch.optim as optim# 定义词汇表
vocab_src = ['<pad>', '<sos>', '<eos>', 'i', 'am', 'a', 'student', 'he', 'is', 'teacher', 'she', 'loves', 'apples', 'we', 'are', 'friends']
vocab_tgt = ['<pad>', '<sos>', '<eos>', 'je', 'suis', 'un', 'étudiant', 'il', 'est', 'professeur', 'elle', 'aime', 'les', 'pommes', 'nous', 'sommes', 'amis']# 创建词汇表映射
src_vocab_size = len(vocab_src)
tgt_vocab_size = len(vocab_tgt)
src_word2idx = {word: idx for idx, word in enumerate(vocab_src)}
tgt_word2idx = {word: idx for idx, word in enumerate(vocab_tgt)}
idx2tgt_word = {idx: word for word, idx in tgt_word2idx.items()}# 句子对
pairs = [["i am a student", "je suis un étudiant"],["he is a teacher", "il est un professeur"],["she loves apples", "elle aime les pommes"],["we are friends", "nous sommes amis"]
]# 将句子转换为索引
def sentence_to_idx(sentence, word2idx):return [word2idx[word] for word in sentence.split()]# 模型参数
embedding_dim = 128
hidden_dim = 128
num_layers = 1# 自注意力层
class SelfAttention(nn.Module):def __init__(self, dim):super(SelfAttention, self).__init__()self.query = nn.Linear(dim, dim)self.key = nn.Linear(dim, dim)self.value = nn.Linear(dim, dim)self.scale = dim ** 0.5def forward(self, x):Q = self.query(x)K = self.key(x)V = self.value(x)attn_weights = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / self.scale, dim=-1)output = torch.bmm(attn_weights, V)return output# 编码器
class Encoder(nn.Module):def __init__(self, input_dim, embedding_dim, hidden_dim):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_dim, embedding_dim)self.attention = SelfAttention(embedding_dim)def forward(self, src):embedded = self.embedding(src)attn_output = self.attention(embedded)return attn_output# 解码器
class Decoder(nn.Module):def __init__(self, output_dim, embedding_dim, hidden_dim):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_dim, embedding_dim)self.attention = SelfAttention(embedding_dim)self.fc = nn.Linear(embedding_dim, output_dim)def forward(self, tgt, encoder_output):embedded = self.embedding(tgt)attn_output = self.attention(embedded)output = self.fc(attn_output)return output# Seq2Seq 模型
class Seq2Seq(nn.Module):def __init__(self, encoder, decoder):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderdef forward(self, src, tgt):encoder_output = self.encoder(src)output = self.decoder(tgt, encoder_output)return output# 数据处理
def prepare_data(pairs):src_data = [torch.tensor(sentence_to_idx(pair[0], src_word2idx), dtype=torch.long) for pair in pairs]tgt_data = [torch.tensor(sentence_to_idx(pair[1], tgt_word2idx), dtype=torch.long) for pair in pairs]return src_data, tgt_data# 初始化模型
encoder = Encoder(src_vocab_size, embedding_dim, hidden_dim)
decoder = Decoder(tgt_vocab_size, embedding_dim, hidden_dim)
model = Seq2Seq(encoder, decoder)# 损失函数和优化器
criterion = nn.CrossEntropyLoss(ignore_index=src_word2idx['<pad>'])
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
def train(model, src_data, tgt_data, epochs=100):model.train()for epoch in range(epochs):total_loss = 0for src, tgt in zip(src_data, tgt_data):src = src.unsqueeze(0)  # (1, seq_len)tgt_input = tgt[:-1].unsqueeze(0)  # 输入解码器,去掉 <eos>tgt_output = tgt[1:].unsqueeze(0)  # 解码器目标,去掉 <sos>optimizer.zero_grad()output = model(src, tgt_input)loss = criterion(output.view(-1, tgt_vocab_size), tgt_output.view(-1))loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(src_data):.4f}')# 数据准备
src_data, tgt_data = prepare_data(pairs)# 训练模型
train(model, src_data, tgt_data, epochs=100)# 翻译函数
def translate(model, sentence):model.eval()src = torch.tensor(sentence_to_idx(sentence, src_word2idx), dtype=torch.long).unsqueeze(0)tgt = torch.tensor([tgt_word2idx['<sos>']], dtype=torch.long).unsqueeze(0)with torch.no_grad():encoder_output = model.encoder(src)for _ in range(10):  # 假设最大输出句长为10output = model.decoder(tgt, encoder_output)pred_token = output.argmax(2)[:, -1].item()tgt = torch.cat([tgt, torch.tensor([[pred_token]])], dim=1)if pred_token == tgt_word2idx['<eos>']:breakreturn ' '.join([idx2tgt_word[idx] for idx in tgt.squeeze().tolist() if idx not in [tgt_word2idx['<sos>'], tgt_word2idx['<eos>']]])# 测试翻译
print(translate(model, "i am a student"))
1. 训练阶段:Teacher Forcing

在训练阶段,模型已经知道目标句子(法语句子),因此可以直接将 目标句子中的真实单词 作为输入提供给解码器。这就是所谓的 Teacher Forcing 技术,它加速了模型的训练过程。

例如:

  • 对于句子 “je suis un étudiant”(目标语言),在训练时,解码器每一步的输入都使用了这个句子中 真实的单词,而不是让模型根据前面生成的单词去预测下一个词。

在训练过程中,模型的解码器的 forward 函数接收到完整的目标序列(除了最后的 <eos> 标记),然后模型通过这个序列生成预测的输出。这种方式在训练时非常有效,因为它让模型更快地学习目标句子的结构。

因此,在训练时,目标句子作为解码器的输入,无需逐步生成:

def forward(self, src, tgt):encoder_output = self.encoder(src)output = self.decoder(tgt, encoder_output)  # 直接使用真实目标序列作为输入return output

在这段代码中,tgt 是目标句子的真实词汇(去掉了 <eos> 标记),直接作为输入传递给解码器。而训练过程中,我们不需要像在推理时那样,每次都手动指定解码器的输入,因为我们直接给出了整个目标序列。

2. 推理阶段:逐步生成

在推理阶段(translate 函数中),没有真实的目标句子,因此模型必须逐步生成目标语言句子。

  • 在翻译的第一步,解码器只知道开始标记 <sos>,于是生成第一个单词。
  • 接下来,解码器使用前一步生成的单词作为输入,结合编码器的输出,生成第二个单词。
  • 这个过程会一直重复,直到生成结束标记 <eos> 或者达到最大句子长度。

这是一个 自回归生成 过程,因为解码器生成当前单词时,需要依赖前面已经生成的单词。

for _ in range(10):output = model.decoder(tgt, encoder_output)  # 使用当前的 tgt 来生成下一个单词pred_token = output.argmax(2)[:, -1].item()  # 取出最后一个时间步的预测结果tgt = torch.cat([tgt, torch.tensor([[pred_token]])], dim=1)  # 将最新生成的单词添加到 tgt 中if pred_token == tgt_word2idx['<eos>']:  # 如果生成了结束符 <eos>,停止生成break

在这段代码中,我们使用逐步生成的方式,逐步将最新生成的单词作为解码器的输入,以模拟实际的生成过程。

六、Self-Attention普通的Attention比较

Self-Attention普通的Attention机制有相似之处,但也存在显著的区别,尤其是在它们的使用场景和关注目标上。以下是它们的主要区别和相似点:

1. 关注目标不同
  • Self-Attention(自注意力):
    Self-Attention机制中的QueryKeyValue 都来自同一个输入序列。在处理输入序列的每个位置时,它计算该位置与序列中所有其他位置的关系。因此,Self-Attention用于一个序列内部,不依赖于外部的输入或输出序列。例如,在Transformer的编码器部分,输入序列的每个词都通过Self-Attention机制与整个输入句子中的其他词进行交互。

    场景:在处理单个序列时,Self-Attention用于捕捉序列内部不同位置的依赖关系,通常用于编码器和解码器内部,尤其在自然语言处理(NLP)任务中。

  • 普通的Attention
    普通的Attention机制通常是指序列到序列模型(Seq2Seq)中的Encoder-Decoder Attention。在这个过程中,解码器的每个时刻的输出作为Query,编码器输出的隐藏状态作为KeyValue。普通的Attention用于让解码器在生成输出时,能够“关注”编码器输出的某些部分,从而获取上下文信息。这里的Query和Key/Value来自不同的序列,通常用于跨序列之间的信息传递。

    场景:在处理多个序列时,普通Attention用于连接输入序列(如源语言)和输出序列(如目标语言),这在机器翻译等Seq2Seq任务中尤为常见。

2. 结构和用法不同
  • Self-Attention
    Self-Attention的输入来自同一个序列。它将输入序列的每个元素都作为Query,同时计算该元素与序列中所有元素(包括它自己)的关系。Self-Attention机制能够让模型捕捉到序列内部各个元素之间的长距离依赖性。

    公式
    在这里插入图片描述
    其中,Query (Q)、Key (K)、Value (V) 都来自相同的输入序列 X,即 Q = K = V = X。

  • 普通的Attention
    普通的Attention机制通常发生在编码器和解码器之间。在序列到序列模型中,解码器每一步的隐藏状态作为Query,而编码器输出的所有隐藏状态作为Key和Value。每一个解码步骤都会基于解码器的当前状态查询整个编码器的输出,从中选择最相关的信息。

    公式
    在这里插入图片描述
    其中,Query (Q) 来自解码器,Key (K) 和 Value (V) 来自编码器的输出。

3. 应用场景不同
  • Self-Attention
    Self-Attention的典型应用是在处理单个序列中的上下文关系。例如,在机器翻译任务中,Transformer模型的编码器通过Self-Attention机制,让每个词与输入句子的所有词进行关联,从而理解整个句子。它被广泛用于自然语言处理中的模型,如BERT、GPT等。

    应用示例

    • Transformer的编码器:每个词通过Self-Attention与其他词进行交互,以捕捉整个输入序列的上下文。
    • BERT模型:利用Self-Attention在句子内部的不同词之间建立联系,从而进行更好的文本理解。
  • 普通的Attention
    普通的Attention主要用于处理跨序列的上下文关系,例如机器翻译任务中,解码器需要生成目标句子中的每个词,而生成每个词时,解码器通过Attention机制选择与当前生成最相关的源句子部分。这个Attention过程通过让解码器“关注”编码器输出的某些部分,生成相应的翻译。

    应用示例

    • 机器翻译中的Seq2Seq模型:解码器在生成目标语言的词时,利用普通Attention机制在源语言的序列中选择最相关的信息。
    • 图像描述生成(Image Captioning):在生成描述时,普通Attention机制帮助模型选择图像中最相关的区域来生成自然语言描述。
4. 捕捉信息的粒度
  • Self-Attention
    Self-Attention机制擅长捕捉序列内部的全局信息,通过关注每个元素与其他元素之间的关系,它能够处理序列中长距离的依赖关系,尤其适合处理长序列任务。

  • 普通的Attention
    普通的Attention更多用于捕捉跨序列之间的信息关联,比如在机器翻译中,从源序列中挑选出与目标序列生成最相关的部分,因此关注的是两个序列之间的联系。

5. 并行化能力
  • Self-Attention
    Self-Attention的一个优势是它能够很好地并行化,因为每个输入位置可以同时计算其与序列中其他位置的关系,这使得模型在训练时能够更有效率地处理长序列数据。

  • 普通的Attention
    普通的Attention在解码器中通常是逐步生成目标序列的每个词,这使得解码过程难以完全并行化。不过,在编码器部分使用Attention时,仍然可以通过并行化的方式处理整个源序列。

6. 时间复杂度
  • Self-Attention
    Self-Attention的计算复杂度是O(n^2),其中n是序列长度,因为每个词需要与其他所有词计算相似度。这在处理长序列时可能会带来计算开销。

  • 普通的Attention
    普通的Attention在解码阶段,Query的数量相对较少,计算复杂度一般为O(nm),其中n是源序列的长度,m是目标序列的长度。这在目标序列较短时相对计算量较少。

总结
特性Self-Attention普通的Attention
输入同一个序列(Query = Key = Value)不同的序列(Query ≠ Key ≠ Value)
应用场景同一序列内部的依赖建模序列间的信息传递
典型任务编码器内部、BERT、GPT机器翻译、图像描述生成
关注点序列内部的全局信息跨序列的信息关联
并行化高效并行解码阶段难以并行
时间复杂度O(n^2)O(nm)

Self-Attention通常用于建模单个序列中的元素间的依赖,而普通的Attention则用于建模两个序列之间的依赖。在Transformer架构中,Self-Attention通常在编码器内部使用,而普通的Attention用于解码器连接编码器的输出,从而生成目标序列。


http://www.ppmy.cn/devtools/126309.html

相关文章

2012年国赛高教杯数学建模A题葡萄酒的评价解题全过程文档及程序

2012年国赛高教杯数学建模 A题 葡萄酒的评价 确定葡萄酒质量时一般是通过聘请一批有资质的评酒员进行品评。每个评酒员在对葡萄酒进行品尝后对其分类指标打分&#xff0c;然后求和得到其总分&#xff0c;从而确定葡萄酒的质量。酿酒葡萄的好坏与所酿葡萄酒的质量有直接的关系&…

自动驾驶高频面试题及答案

目录 高频面试题及答案1. 什么是自动驾驶?2. 自动驾驶的主要传感器有哪些?3. 自动驾驶中的感知与决策有什么区别?4. 解释一下自动驾驶的等级划分。5. 如何处理自动驾驶中的安全性问题?6. 自动驾驶车辆如何实现环境感知?7. 在自动驾驶中,如何处理车辆之间的通信?8. 自动驾…

【内网映射】frps实现内网映射

1. 简介 在当今互联网时代,远程访问内网资源已成为一种常见需求。无论是在家访问办公室的电脑,还是远程管理家庭NAS,内网映射都是一种强大的解决方案。 本文将详细介绍如何使用frp(Fast Reverse Proxy)来实现这一目标。 1.1 frp frp是一个高性能的反向代理应用,可以帮助您轻…

Python爬虫高效数据爬取方法

大家好!今天我们来聊聊Python爬虫中那些既简洁又高效的数据爬取方法。作为一名爬虫工程师,我们总是希望用最少的代码完成最多的工作。下面我ll分享一些在使用requests库进行网络爬虫时常用且高效的函数和方法。 1. requests.get() - 简单而强大 requests.get()是我们最常用的…

矩阵相关算法

矩阵旋转90度 给定一个 n n 的二维矩阵 matrix 表示一个图像&#xff0c;请你将图像顺时针旋转 90 度。 #include <iostream> #include <vector>using namespace std;void rotate(vector<vector<int>>& matrix) {int n matrix.size();// 第一步…

Three.js 快速入门 --- 鼠标操作三维场景

1、准备工作 需要引入 OrbitControls.js <script src"./three.js-r102/examples/js/controls/OrbitControls.js"></script>2、代码实现 function render() {renderer.render(scene,camera);//执行渲染操作 } render(); var controls new THREE.OrbitC…

全面掌握 Linux 服务管理:从入门到精通

全面掌握 Linux 服务管理&#xff1a;从入门到精通 引言 在 Linux 系统中&#xff0c;服务管理是系统管理员和开发者的基本技能之一。无论是启动、停止、重启还是查看服务状态&#xff0c;systemctl 命令都能让你轻松完成这些操作。今天&#xff0c;我们将深入探讨如何使用 sy…

系统架构设计师:数据库系统相关考题预测

作为系统架构设计师,在准备数据库系统相关的考试时,可以预期到的一些关键知识点包括但不限于以下几个方面: 数据库类型: 关系型数据库(RDBMS)与非关系型数据库(NoSQL)的区别及其适用场景。数据库管理系统(DBMS)的功能及组成部分。数据模型: 如何设计ER模型(实体-关…