文章目录
- 介绍
- Bahdanau 注意力(Bahdanau Attention)
- 原理
- 公式含义
- 计算过程
- 编码器部分
- 注意力机制部分
- 解码器部分
- 计算过程
- 代码实现
- 导包
- 定义注意力解码器
- 添加Bahdanau的decoder
- 训练
- 评估指标 bleu
- 开始预测
个人主页:道友老李
欢迎加入社区:道友老李的学习社区
介绍
**自然语言处理(Natural Language Processing,NLP)**是计算机科学领域与人工智能领域中的一个重要方向。它研究的是人类(自然)语言与计算机之间的交互。NLP的目标是让计算机能够理解、解析、生成人类语言,并且能够以有意义的方式回应和操作这些信息。
NLP的任务可以分为多个层次,包括但不限于:
- 词法分析:将文本分解成单词或标记(token),并识别它们的词性(如名词、动词等)。
- 句法分析:分析句子结构,理解句子中词语的关系,比如主语、谓语、宾语等。
- 语义分析:试图理解句子的实际含义,超越字面意义,捕捉隐含的信息。
- 语用分析:考虑上下文和对话背景,理解话语在特定情境下的使用目的。
- 情感分析:检测文本中表达的情感倾向,例如正面、负面或中立。
- 机器翻译:将一种自然语言转换为另一种自然语言。
- 问答系统:构建可以回答用户问题的系统。
- 文本摘要:从大量文本中提取关键信息,生成简短的摘要。
- 命名实体识别(NER):识别文本中提到的特定实体,如人名、地名、组织名等。
- 语音识别:将人类的语音转换为计算机可读的文字格式。
NLP技术的发展依赖于算法的进步、计算能力的提升以及大规模标注数据集的可用性。近年来,深度学习方法,特别是基于神经网络的语言模型,如BERT、GPT系列等,在许多NLP任务上取得了显著的成功。随着技术的进步,NLP正在被应用到越来越多的领域,包括客户服务、智能搜索、内容推荐、医疗健康等。
Bahdanau 注意力(Bahdanau Attention)
Bahdanau注意力(Bahdanau Attention)是自然语言处理中一种经典的注意力机制。
在传统的编码器 - 解码器架构(如基于RNN的架构)中,编码器将整个输入序列编码为一个固定长度的向量,解码器依赖该向量生成输出。当输入序列较长时,这种固定长度向量难以存储所有重要信息,导致性能下降。Bahdanau注意力机制通过让解码器在生成每个输出时,动态关注输入序列不同部分,解决此问题。
原理
允许解码器在生成输出时,根据当前状态,从编码器的隐藏状态序列中选择性聚焦,获取与当前生成任务最相关信息,而非仅依赖单一固定向量。
Bahdanau 注意力机制中计算上下文向量的公式:
公式含义
- c t c_t ct 表示在解码器的时间步 t t t 时得到的上下文向量,它综合了编码器隐藏状态序列中的信息,用于辅助解码器在该时间步生成输出。
- T T T 是编码器的时间步总数,意味着要考虑编码器所有时间步的隐藏状态。
- α ( s t − 1 , h i ) \alpha(s_{t - 1}, h_i) α(st−1,hi) 是注意力权重,它表示在解码器时间步 t − 1 t - 1 t−1 的隐藏状态 s t − 1 s_{t - 1} st−1 条件下,对编码器第 i i i 个时间步隐藏状态 h i h_i hi 的关注程度。这个权重是通过一个特定的计算(通常涉及一个小型神经网络来计算相似度等)得到,并经过softmax函数归一化,取值范围在 0 0 0 到 1 1 1 之间,且 ∑ i = 1 T α ( s t − 1 , h i ) = 1 \sum_{i = 1}^{T}\alpha(s_{t - 1}, h_i)=1 ∑i=1Tα(st−1,hi)=1。
- h i h_i hi 是编码器在第 i i i 个时间步的隐藏状态,它包含了输入序列在该时间步及之前的信息。
计算过程
- 首先,根据解码器上一个时间步的隐藏状态 s t − 1 s_{t - 1} st−1 和编码器所有时间步的隐藏状态 h i h_i hi( i i i 从 1 1 1 到 T T T),计算出每个 h i h_i hi 对应的注意力权重 α ( s t − 1 , h i ) \alpha(s_{t - 1}, h_i) α(st−1,hi)。
- 然后,将这些注意力权重分别与对应的编码器隐藏状态 h i h_i hi 相乘,并对所有时间步的乘积结果进行求和,就得到了当前解码器时间步 t t t 的上下文向量 c t c_t ct。
这个上下文向量 c t c_t ct 后续会与解码器当前时间步 t t t 的隐藏状态等信息结合,用于生成当前时间步的输出,比如在机器翻译任务中预测目标语言的下一个单词。
一个带有Bahdanau注意力的循环神经网络编码器-解码器模型:
编码器部分
- 嵌入层:将源序列(如源语言句子中的单词)从离散的符号转换为低维、连续的向量表示,即词嵌入,便于模型后续处理,同时捕捉单词语义关系。
- 循环层:一般由RNN、LSTM或GRU等单元构成。按顺序处理嵌入层输出的向量序列,每个时间步结合当前输入和上一时刻隐藏状态更新隐藏状态,逐步将源序列信息编码到隐藏状态中,最终输出包含源序列语义信息的隐藏状态序列。
注意力机制部分
位于编码器和解码器之间,允许解码器在生成输出时,根据当前状态从编码器的隐藏状态序列中动态选择相关信息。它计算解码器当前隐藏状态与编码器各时间步隐藏状态的相关性,得到注意力权重,对编码器隐藏状态加权求和生成上下文向量,为解码器提供与当前生成任务相关的信息。
解码器部分
- 嵌入层:与编码器的嵌入层类似,将目标序列(如目标语言句子中的单词)的离散符号转换为向量表示,不过针对目标语言。
- 循环层:接收编码器输出的隐藏状态序列以及注意力机制生成的上下文向量,结合目标序列嵌入向量,按顺序处理并更新隐藏状态,生成目标序列下一个元素的预测。
- 全连接层:对循环层输出进行处理,将隐藏状态映射到目标词汇表维度,经softmax函数计算词汇表中每个单词的概率分布,预测当前时间步最可能的输出单词。
该架构在机器翻译、文本摘要等序列到序列任务中应用广泛,注意力机制可有效解决长序列信息处理难题,提升模型性能。
计算过程
- 计算注意力分数:解码器在时间步 t t t的隐藏状态 h t d e c h_t^{dec} htdec作为查询(query),与编码器所有时间步的隐藏状态 h i e n c h_i^{enc} hienc( i = 1 , ⋯ , T i = 1, \cdots, T i=1,⋯,T, T T T为编码器时间步数)计算注意力分数 e t , i e_{t,i} et,i,一般通过一个小型神经网络计算,如 e t , i = a ( h t d e c , h i e n c ) e_{t,i}=a(h_t^{dec}, h_i^{enc}) et,i=a(htdec,hienc), a a a是一个非线性函数。
- 归一化注意力分数:将注意力分数 e t , i e_{t,i} et,i通过softmax函数归一化,得到注意力权重 α t , i \alpha_{t,i} αt,i,即 α t , i = exp ( e t , i ) ∑ j = 1 T exp ( e t , j ) \alpha_{t,i}=\frac{\exp(e_{t,i})}{\sum_{j = 1}^{T}\exp(e_{t,j})} αt,i=∑j=1Texp(et,j)exp(et,i),表示编码器第 i i i个时间步对解码器当前时间步 t t t的重要程度。
- 计算上下文向量:根据注意力权重对编码器隐藏状态加权求和,得到上下文向量 c t c_t ct, c t = ∑ i = 1 T α t , i h i e n c c_t=\sum_{i = 1}^{T}\alpha_{t,i}h_i^{enc} ct=∑i=1Tαt,ihienc,它包含了与当前生成任务相关的输入信息。
- 生成输出:上下文向量 c t c_t ct与解码器当前隐藏状态 h t d e c h_t^{dec} htdec结合,如拼接后输入到后续网络层,生成当前时间步的输出。
代码实现
导包
import torch
from torch import nn
import dltools
定义注意力解码器
class AttentionDecoder(dltools.Decoder):def __init__(self, **kwargs):super().__init__(**kwargs)@propertydef attention_weights(self):raise NotImplementedError
添加Bahdanau的decoder
class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)self.attention = dltools.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs : (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# X : (batch_size, num_steps, vocab_size)X = self.embedding(X) # X : (batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens# print('query:', query.shape) # 4, 1, 16context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)# print('context: ', context.shape)x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)# print('x: ', x.shape)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)# print('out:', out.shape)# print('hidden_state:', hidden_state.shape)outputs.append(out)self._attention_weights.append(self.attention_weights)# print('---------------------------------')outputs = self.dense(torch.cat(outputs, dim=0))# print('解码器最终输出形状: ', outputs.shape)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weightsencoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()# batch_size 4, num_steps 7
X = torch.zeros((4, 7), dtype=torch.long)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query: torch.Size([4, 1, 16])
context: torch.Size([4, 1, 16])
x: torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context: torch.Size([4, 1, 16])
x: torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context: torch.Size([4, 1, 16])
x: torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context: torch.Size([4, 1, 16])
x: torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context: torch.Size([4, 1, 16])
x: torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context: torch.Size([4, 1, 16])
x: torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context: torch.Size([4, 1, 16])
x: torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
---------------------------------
解码器最终输出形状: torch.Size([7, 4, 10])
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))
训练
执行训练前,将decoder中的print屏蔽掉!!
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
评估指标 bleu
def bleu(pred_seq, label_seq, k):print('pred_seq', pred_seq)print('label_seq:', label_seq)pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')len_pred, len_label = len(pred_tokens), len(label_tokens)score = math.exp(min(0, 1 - (len_label / len_pred)))for n in range(1, k + 1):num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i + n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i + n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i + n])] -= 1score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n)) return score
开始预测
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est malade .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000