Self-Attention 机制(Self-Attention Mechanism)是近年来深度学习中,特别是自然语言处理(NLP)任务中广泛应用的一种机制,它最初出现在Transformer架构中,用于捕捉序列数据中的依赖关系。该机制的核心思想是:通过给定序列的所有元素之间分配不同的权重,模型能够灵活地关注序列中的相关部分,从而更好地理解上下文关系。
一、背景
传统的序列处理模型,如RNN(循环神经网络)或LSTM(长短期记忆网络),通常处理序列的方式是依赖于时间步(timestep)上的顺序信息。这意味着模型在捕捉长距离依赖(long-range dependencies)时,效率较低。而Self-Attention则通过全局的方式来捕捉序列中任意两个元素之间的关系,不需要按照时间步逐一处理。
二、Self-Attention 的工作原理
Self-Attention 通过计算 Query、Key、Value,让序列中的每个元素与序列的其他元素进行比较,计算它们的相似性。
然后,Self-Attention 根据这个相似性来更新每个元素的表示,使得每个元素的表示不再只是它自己的信息,而是结合了与序列中其他元素的关联信息。
假设输入序列为 {x1,x2,…,xn},每个元素 xi 是一个向量。
具体步骤如下:
-
计算 Query、Key 和 Value
- 对每个输入 xi,我们生成三个向量:Query (Q),Key (K),和 Value (V)。
Query (Q) 表示当前元素“想要询问什么”,Key (K) 表示当前元素“可以被询问什么”,Value (V) 承载实际的输入信息。
这些向量由输入向量通过三组不同的可学习权重矩阵线性变换得到:
其中,Wq、Wk、Wv 是可学习的权重矩阵。
- 对每个输入 xi,我们生成三个向量:Query (Q),Key (K),和 Value (V)。
-
计算注意力得分
- 对第 i 个元素(第 i 个 Query 向量 Qi)来说,要与所有的 Key 向量进行相似性计算,即要计算 Qi与所有 Kj(j=1,2,…,n)之间的相似性得分。通过点积来实现:
这里,dk 是 Key 向量的维度。缩放因子 是为了稳定训练过程,防止值过大。
这些得分反映了输入序列中不同元素之间的相关性。
- 对第 i 个元素(第 i 个 Query 向量 Qi)来说,要与所有的 Key 向量进行相似性计算,即要计算 Qi与所有 Kj(j=1,2,…,n)之间的相似性得分。通过点积来实现:
-
Softmax 归一化
-
计算完所有的相似度后,通过 Softmax 函数对这些得分进行归一化,得到权重分布:
这一步可以确保所有权重的和为 1。
-
-
加权求和得到输出
-
最后,使用这些归一化后的权重 αij 对 Value 向量 Vj 进行加权求和,得到更新后的表示:
-
这样,每个输入 xi 的输出不仅仅依赖于它自身,还依赖于与其他所有输入的关系权重。这使得模型能够更好地捕捉输入序列中远程词语之间的关系。
三、多头注意力机制
为了让模型在不同的子空间上捕捉不同的注意力模式,Multi-Head Attention(多头注意力)机制进一步扩展了 Self-Attention。具体来说,多头注意力会将输入的 Query、Key 和 Value 向量分成多个头(head),在每个头上分别执行注意力操作,最后将各头的输出进行拼接并投影回原空间。
这种方法使模型可以从多个角度理解输入序列中的不同部分,从而增强模型的表达能力。
四、Self-Attention 的优势
- 并行化计算:不同于RNN等需要逐步处理输入序列,Self-Attention 允许并行计算,使得训练速度大大提升。
- 捕捉长距离依赖关系:由于每个输入都可以关注整个序列的所有其他元素,Self-Attention 机制能够很好地捕捉长距离的依赖关系。
- 可解释性:Attention 权重可以直观地解释模型关注了哪些输入,这为模型的决策过程提供了可解释性。
五、应用
Self-Attention 机制是 Transformer 模型的核心,并且广泛应用于多种任务中,包括:
- 机器翻译:如 Google 的 Transformer 模型被广泛应用于机器翻译任务中。
- 文本生成:如 GPT 系列模型中的自回归 Transformer 结构。
- 文本分类和问答系统:BERT 等双向 Transformer 结构也基于 Self-Attention,广泛用于各种 NLP 任务。
总的来说,Self-Attention 机制极大地推动了深度学习领域,特别是自然语言处理技术的发展。
实现一个简化的 Self-Attention Transformer 机器翻译模型,它没有包括多头注意力(Multi-Head Attention)、位置编码(Positional Encoding)等功能:
import torch
import torch.nn as nn
import torch.optim as optim# 定义词汇表
vocab_src = ['<pad>', '<sos>', '<eos>', 'i', 'am', 'a', 'student', 'he', 'is', 'teacher', 'she', 'loves', 'apples', 'we', 'are', 'friends']
vocab_tgt = ['<pad>', '<sos>', '<eos>', 'je', 'suis', 'un', 'étudiant', 'il', 'est', 'professeur', 'elle', 'aime', 'les', 'pommes', 'nous', 'sommes', 'amis']# 创建词汇表映射
src_vocab_size = len(vocab_src)
tgt_vocab_size = len(vocab_tgt)
src_word2idx = {word: idx for idx, word in enumerate(vocab_src)}
tgt_word2idx = {word: idx for idx, word in enumerate(vocab_tgt)}
idx2tgt_word = {idx: word for word, idx in tgt_word2idx.items()}# 句子对
pairs = [["i am a student", "je suis un étudiant"],["he is a teacher", "il est un professeur"],["she loves apples", "elle aime les pommes"],["we are friends", "nous sommes amis"]
]# 将句子转换为索引
def sentence_to_idx(sentence, word2idx):return [word2idx[word] for word in sentence.split()]# 模型参数
embedding_dim = 128
hidden_dim = 128
num_layers = 1# 自注意力层
class SelfAttention(nn.Module):def __init__(self, dim):super(SelfAttention, self).__init__()self.query = nn.Linear(dim, dim)self.key = nn.Linear(dim, dim)self.value = nn.Linear(dim, dim)self.scale = dim ** 0.5def forward(self, x):Q = self.query(x)K = self.key(x)V = self.value(x)attn_weights = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / self.scale, dim=-1)output = torch.bmm(attn_weights, V)return output# 编码器
class Encoder(nn.Module):def __init__(self, input_dim, embedding_dim, hidden_dim):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_dim, embedding_dim)self.attention = SelfAttention(embedding_dim)def forward(self, src):embedded = self.embedding(src)attn_output = self.attention(embedded)return attn_output# 解码器
class Decoder(nn.Module):def __init__(self, output_dim, embedding_dim, hidden_dim):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_dim, embedding_dim)self.attention = SelfAttention(embedding_dim)self.fc = nn.Linear(embedding_dim, output_dim)def forward(self, tgt, encoder_output):embedded = self.embedding(tgt)attn_output = self.attention(embedded)output = self.fc(attn_output)return output# Seq2Seq 模型
class Seq2Seq(nn.Module):def __init__(self, encoder, decoder):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderdef forward(self, src, tgt):encoder_output = self.encoder(src)output = self.decoder(tgt, encoder_output)return output# 数据处理
def prepare_data(pairs):src_data = [torch.tensor(sentence_to_idx(pair[0], src_word2idx), dtype=torch.long) for pair in pairs]tgt_data = [torch.tensor(sentence_to_idx(pair[1], tgt_word2idx), dtype=torch.long) for pair in pairs]return src_data, tgt_data# 初始化模型
encoder = Encoder(src_vocab_size, embedding_dim, hidden_dim)
decoder = Decoder(tgt_vocab_size, embedding_dim, hidden_dim)
model = Seq2Seq(encoder, decoder)# 损失函数和优化器
criterion = nn.CrossEntropyLoss(ignore_index=src_word2idx['<pad>'])
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
def train(model, src_data, tgt_data, epochs=100):model.train()for epoch in range(epochs):total_loss = 0for src, tgt in zip(src_data, tgt_data):src = src.unsqueeze(0) # (1, seq_len)tgt_input = tgt[:-1].unsqueeze(0) # 输入解码器,去掉 <eos>tgt_output = tgt[1:].unsqueeze(0) # 解码器目标,去掉 <sos>optimizer.zero_grad()output = model(src, tgt_input)loss = criterion(output.view(-1, tgt_vocab_size), tgt_output.view(-1))loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(src_data):.4f}')# 数据准备
src_data, tgt_data = prepare_data(pairs)# 训练模型
train(model, src_data, tgt_data, epochs=100)# 翻译函数
def translate(model, sentence):model.eval()src = torch.tensor(sentence_to_idx(sentence, src_word2idx), dtype=torch.long).unsqueeze(0)tgt = torch.tensor([tgt_word2idx['<sos>']], dtype=torch.long).unsqueeze(0)with torch.no_grad():encoder_output = model.encoder(src)for _ in range(10): # 假设最大输出句长为10output = model.decoder(tgt, encoder_output)pred_token = output.argmax(2)[:, -1].item()tgt = torch.cat([tgt, torch.tensor([[pred_token]])], dim=1)if pred_token == tgt_word2idx['<eos>']:breakreturn ' '.join([idx2tgt_word[idx] for idx in tgt.squeeze().tolist() if idx not in [tgt_word2idx['<sos>'], tgt_word2idx['<eos>']]])# 测试翻译
print(translate(model, "i am a student"))
1. 训练阶段:Teacher Forcing
在训练阶段,模型已经知道目标句子(法语句子),因此可以直接将 目标句子中的真实单词 作为输入提供给解码器。这就是所谓的 Teacher Forcing 技术,它加速了模型的训练过程。
例如:
- 对于句子 “je suis un étudiant”(目标语言),在训练时,解码器每一步的输入都使用了这个句子中 真实的单词,而不是让模型根据前面生成的单词去预测下一个词。
在训练过程中,模型的解码器的 forward
函数接收到完整的目标序列(除了最后的 <eos>
标记),然后模型通过这个序列生成预测的输出。这种方式在训练时非常有效,因为它让模型更快地学习目标句子的结构。
因此,在训练时,目标句子作为解码器的输入,无需逐步生成:
def forward(self, src, tgt):encoder_output = self.encoder(src)output = self.decoder(tgt, encoder_output) # 直接使用真实目标序列作为输入return output
在这段代码中,tgt
是目标句子的真实词汇(去掉了 <eos>
标记),直接作为输入传递给解码器。而训练过程中,我们不需要像在推理时那样,每次都手动指定解码器的输入,因为我们直接给出了整个目标序列。
2. 推理阶段:逐步生成
在推理阶段(translate
函数中),没有真实的目标句子,因此模型必须逐步生成目标语言句子。
- 在翻译的第一步,解码器只知道开始标记
<sos>
,于是生成第一个单词。 - 接下来,解码器使用前一步生成的单词作为输入,结合编码器的输出,生成第二个单词。
- 这个过程会一直重复,直到生成结束标记
<eos>
或者达到最大句子长度。
这是一个 自回归生成 过程,因为解码器生成当前单词时,需要依赖前面已经生成的单词。
for _ in range(10):output = model.decoder(tgt, encoder_output) # 使用当前的 tgt 来生成下一个单词pred_token = output.argmax(2)[:, -1].item() # 取出最后一个时间步的预测结果tgt = torch.cat([tgt, torch.tensor([[pred_token]])], dim=1) # 将最新生成的单词添加到 tgt 中if pred_token == tgt_word2idx['<eos>']: # 如果生成了结束符 <eos>,停止生成break
在这段代码中,我们使用逐步生成的方式,逐步将最新生成的单词作为解码器的输入,以模拟实际的生成过程。
六、Self-Attention和普通的Attention比较
Self-Attention和普通的Attention机制有相似之处,但也存在显著的区别,尤其是在它们的使用场景和关注目标上。以下是它们的主要区别和相似点:
1. 关注目标不同
-
Self-Attention(自注意力):
Self-Attention机制中的Query、Key、Value 都来自同一个输入序列。在处理输入序列的每个位置时,它计算该位置与序列中所有其他位置的关系。因此,Self-Attention用于一个序列内部,不依赖于外部的输入或输出序列。例如,在Transformer的编码器部分,输入序列的每个词都通过Self-Attention机制与整个输入句子中的其他词进行交互。场景:在处理单个序列时,Self-Attention用于捕捉序列内部不同位置的依赖关系,通常用于编码器和解码器内部,尤其在自然语言处理(NLP)任务中。
-
普通的Attention:
普通的Attention机制通常是指序列到序列模型(Seq2Seq)中的Encoder-Decoder Attention。在这个过程中,解码器的每个时刻的输出作为Query,编码器输出的隐藏状态作为Key和Value。普通的Attention用于让解码器在生成输出时,能够“关注”编码器输出的某些部分,从而获取上下文信息。这里的Query和Key/Value来自不同的序列,通常用于跨序列之间的信息传递。场景:在处理多个序列时,普通Attention用于连接输入序列(如源语言)和输出序列(如目标语言),这在机器翻译等Seq2Seq任务中尤为常见。
2. 结构和用法不同
-
Self-Attention:
Self-Attention的输入来自同一个序列。它将输入序列的每个元素都作为Query,同时计算该元素与序列中所有元素(包括它自己)的关系。Self-Attention机制能够让模型捕捉到序列内部各个元素之间的长距离依赖性。公式:
其中,Query (Q)、Key (K)、Value (V) 都来自相同的输入序列 X,即 Q = K = V = X。 -
普通的Attention:
普通的Attention机制通常发生在编码器和解码器之间。在序列到序列模型中,解码器每一步的隐藏状态作为Query,而编码器输出的所有隐藏状态作为Key和Value。每一个解码步骤都会基于解码器的当前状态查询整个编码器的输出,从中选择最相关的信息。公式:
其中,Query (Q) 来自解码器,Key (K) 和 Value (V) 来自编码器的输出。
3. 应用场景不同
-
Self-Attention:
Self-Attention的典型应用是在处理单个序列中的上下文关系。例如,在机器翻译任务中,Transformer模型的编码器通过Self-Attention机制,让每个词与输入句子的所有词进行关联,从而理解整个句子。它被广泛用于自然语言处理中的模型,如BERT、GPT等。应用示例:
- Transformer的编码器:每个词通过Self-Attention与其他词进行交互,以捕捉整个输入序列的上下文。
- BERT模型:利用Self-Attention在句子内部的不同词之间建立联系,从而进行更好的文本理解。
-
普通的Attention:
普通的Attention主要用于处理跨序列的上下文关系,例如机器翻译任务中,解码器需要生成目标句子中的每个词,而生成每个词时,解码器通过Attention机制选择与当前生成最相关的源句子部分。这个Attention过程通过让解码器“关注”编码器输出的某些部分,生成相应的翻译。应用示例:
- 机器翻译中的Seq2Seq模型:解码器在生成目标语言的词时,利用普通Attention机制在源语言的序列中选择最相关的信息。
- 图像描述生成(Image Captioning):在生成描述时,普通Attention机制帮助模型选择图像中最相关的区域来生成自然语言描述。
4. 捕捉信息的粒度
-
Self-Attention:
Self-Attention机制擅长捕捉序列内部的全局信息,通过关注每个元素与其他元素之间的关系,它能够处理序列中长距离的依赖关系,尤其适合处理长序列任务。 -
普通的Attention:
普通的Attention更多用于捕捉跨序列之间的信息关联,比如在机器翻译中,从源序列中挑选出与目标序列生成最相关的部分,因此关注的是两个序列之间的联系。
5. 并行化能力
-
Self-Attention:
Self-Attention的一个优势是它能够很好地并行化,因为每个输入位置可以同时计算其与序列中其他位置的关系,这使得模型在训练时能够更有效率地处理长序列数据。 -
普通的Attention:
普通的Attention在解码器中通常是逐步生成目标序列的每个词,这使得解码过程难以完全并行化。不过,在编码器部分使用Attention时,仍然可以通过并行化的方式处理整个源序列。
6. 时间复杂度
-
Self-Attention:
Self-Attention的计算复杂度是O(n^2),其中n是序列长度,因为每个词需要与其他所有词计算相似度。这在处理长序列时可能会带来计算开销。 -
普通的Attention:
普通的Attention在解码阶段,Query的数量相对较少,计算复杂度一般为O(nm),其中n是源序列的长度,m是目标序列的长度。这在目标序列较短时相对计算量较少。
总结
特性 | Self-Attention | 普通的Attention |
---|---|---|
输入 | 同一个序列(Query = Key = Value) | 不同的序列(Query ≠ Key ≠ Value) |
应用场景 | 同一序列内部的依赖建模 | 序列间的信息传递 |
典型任务 | 编码器内部、BERT、GPT | 机器翻译、图像描述生成 |
关注点 | 序列内部的全局信息 | 跨序列的信息关联 |
并行化 | 高效并行 | 解码阶段难以并行 |
时间复杂度 | O(n^2) | O(nm) |
Self-Attention通常用于建模单个序列中的元素间的依赖,而普通的Attention则用于建模两个序列之间的依赖。在Transformer架构中,Self-Attention通常在编码器内部使用,而普通的Attention用于解码器连接编码器的输出,从而生成目标序列。