【自然语言处理(NLP)】Bahdanau 注意力(Bahdanau Attention)原理及代码实现

ops/2025/1/30 21:50:48/

文章目录

  • 介绍
  • Bahdanau 注意力(Bahdanau Attention)
    • 原理
      • 公式含义
      • 计算过程
      • 编码器部分
      • 注意力机制部分
      • 解码器部分
    • 计算过程
    • 代码实现
      • 导包
      • 定义注意力解码器
      • 添加Bahdanau的decoder
      • 训练
      • 评估指标 bleu
      • 开始预测

个人主页:道友老李
欢迎加入社区:道友老李的学习社区

介绍

**自然语言处理(Natural Language Processing,NLP)**是计算机科学领域与人工智能领域中的一个重要方向。它研究的是人类(自然)语言与计算机之间的交互。NLP的目标是让计算机能够理解、解析、生成人类语言,并且能够以有意义的方式回应和操作这些信息。

NLP的任务可以分为多个层次,包括但不限于:

  1. 词法分析:将文本分解成单词或标记(token),并识别它们的词性(如名词、动词等)。
  2. 句法分析:分析句子结构,理解句子中词语的关系,比如主语、谓语、宾语等。
  3. 语义分析:试图理解句子的实际含义,超越字面意义,捕捉隐含的信息。
  4. 语用分析:考虑上下文和对话背景,理解话语在特定情境下的使用目的。
  5. 情感分析:检测文本中表达的情感倾向,例如正面、负面或中立。
  6. 机器翻译:将一种自然语言转换为另一种自然语言。
  7. 问答系统:构建可以回答用户问题的系统。
  8. 文本摘要:从大量文本中提取关键信息,生成简短的摘要。
  9. 命名实体识别(NER):识别文本中提到的特定实体,如人名、地名、组织名等。
  10. 语音识别:将人类的语音转换为计算机可读的文字格式。

NLP技术的发展依赖于算法的进步、计算能力的提升以及大规模标注数据集的可用性。近年来,深度学习方法,特别是基于神经网络的语言模型,如BERT、GPT系列等,在许多NLP任务上取得了显著的成功。随着技术的进步,NLP正在被应用到越来越多的领域,包括客户服务、智能搜索、内容推荐、医疗健康等。

Bahdanau 注意力(Bahdanau Attention)

Bahdanau注意力(Bahdanau Attention)是自然语言处理中一种经典的注意力机制。

在传统的编码器 - 解码器架构(如基于RNN的架构)中,编码器将整个输入序列编码为一个固定长度的向量,解码器依赖该向量生成输出。当输入序列较长时,这种固定长度向量难以存储所有重要信息,导致性能下降。Bahdanau注意力机制通过让解码器在生成每个输出时,动态关注输入序列不同部分,解决此问题。

原理

允许解码器在生成输出时,根据当前状态,从编码器的隐藏状态序列中选择性聚焦,获取与当前生成任务最相关信息,而非仅依赖单一固定向量。

Bahdanau 注意力机制中计算上下文向量的公式:
在这里插入图片描述

公式含义

  • c t c_t ct 表示在解码器的时间步 t t t 时得到的上下文向量,它综合了编码器隐藏状态序列中的信息,用于辅助解码器在该时间步生成输出。
  • T T T 是编码器的时间步总数,意味着要考虑编码器所有时间步的隐藏状态。
  • α ( s t − 1 , h i ) \alpha(s_{t - 1}, h_i) α(st1,hi) 是注意力权重,它表示在解码器时间步 t − 1 t - 1 t1 的隐藏状态 s t − 1 s_{t - 1} st1 条件下,对编码器第 i i i 个时间步隐藏状态 h i h_i hi 的关注程度。这个权重是通过一个特定的计算(通常涉及一个小型神经网络来计算相似度等)得到,并经过softmax函数归一化,取值范围在 0 0 0 1 1 1 之间,且 ∑ i = 1 T α ( s t − 1 , h i ) = 1 \sum_{i = 1}^{T}\alpha(s_{t - 1}, h_i)=1 i=1Tα(st1,hi)=1
  • h i h_i hi 是编码器在第 i i i 个时间步的隐藏状态,它包含了输入序列在该时间步及之前的信息。

计算过程

  • 首先,根据解码器上一个时间步的隐藏状态 s t − 1 s_{t - 1} st1 和编码器所有时间步的隐藏状态 h i h_i hi i i i 1 1 1 T T T),计算出每个 h i h_i hi 对应的注意力权重 α ( s t − 1 , h i ) \alpha(s_{t - 1}, h_i) α(st1,hi)
  • 然后,将这些注意力权重分别与对应的编码器隐藏状态 h i h_i hi 相乘,并对所有时间步的乘积结果进行求和,就得到了当前解码器时间步 t t t 的上下文向量 c t c_t ct

这个上下文向量 c t c_t ct 后续会与解码器当前时间步 t t t 的隐藏状态等信息结合,用于生成当前时间步的输出,比如在机器翻译任务中预测目标语言的下一个单词。

一个带有Bahdanau注意力的循环神经网络编码器-解码器模型:
在这里插入图片描述

编码器部分

  • 嵌入层:将源序列(如源语言句子中的单词)从离散的符号转换为低维、连续的向量表示,即词嵌入,便于模型后续处理,同时捕捉单词语义关系。
  • 循环层:一般由RNN、LSTM或GRU等单元构成。按顺序处理嵌入层输出的向量序列,每个时间步结合当前输入和上一时刻隐藏状态更新隐藏状态,逐步将源序列信息编码到隐藏状态中,最终输出包含源序列语义信息的隐藏状态序列。

注意力机制部分

位于编码器和解码器之间,允许解码器在生成输出时,根据当前状态从编码器的隐藏状态序列中动态选择相关信息。它计算解码器当前隐藏状态与编码器各时间步隐藏状态的相关性,得到注意力权重,对编码器隐藏状态加权求和生成上下文向量,为解码器提供与当前生成任务相关的信息。

解码器部分

  • 嵌入层:与编码器的嵌入层类似,将目标序列(如目标语言句子中的单词)的离散符号转换为向量表示,不过针对目标语言。
  • 循环层:接收编码器输出的隐藏状态序列以及注意力机制生成的上下文向量,结合目标序列嵌入向量,按顺序处理并更新隐藏状态,生成目标序列下一个元素的预测。
  • 全连接层:对循环层输出进行处理,将隐藏状态映射到目标词汇表维度,经softmax函数计算词汇表中每个单词的概率分布,预测当前时间步最可能的输出单词。

该架构在机器翻译、文本摘要等序列到序列任务中应用广泛,注意力机制可有效解决长序列信息处理难题,提升模型性能。

计算过程

  1. 计算注意力分数:解码器在时间步 t t t的隐藏状态 h t d e c h_t^{dec} htdec作为查询(query),与编码器所有时间步的隐藏状态 h i e n c h_i^{enc} hienc i = 1 , ⋯ , T i = 1, \cdots, T i=1,,T T T T为编码器时间步数)计算注意力分数 e t , i e_{t,i} et,i,一般通过一个小型神经网络计算,如 e t , i = a ( h t d e c , h i e n c ) e_{t,i}=a(h_t^{dec}, h_i^{enc}) et,i=a(htdec,hienc) a a a是一个非线性函数。
  2. 归一化注意力分数:将注意力分数 e t , i e_{t,i} et,i通过softmax函数归一化,得到注意力权重 α t , i \alpha_{t,i} αt,i,即 α t , i = exp ⁡ ( e t , i ) ∑ j = 1 T exp ⁡ ( e t , j ) \alpha_{t,i}=\frac{\exp(e_{t,i})}{\sum_{j = 1}^{T}\exp(e_{t,j})} αt,i=j=1Texp(et,j)exp(et,i),表示编码器第 i i i个时间步对解码器当前时间步 t t t的重要程度。
  3. 计算上下文向量:根据注意力权重对编码器隐藏状态加权求和,得到上下文向量 c t c_t ct c t = ∑ i = 1 T α t , i h i e n c c_t=\sum_{i = 1}^{T}\alpha_{t,i}h_i^{enc} ct=i=1Tαt,ihienc,它包含了与当前生成任务相关的输入信息。
  4. 生成输出:上下文向量 c t c_t ct与解码器当前隐藏状态 h t d e c h_t^{dec} htdec结合,如拼接后输入到后续网络层,生成当前时间步的输出。

代码实现

导包

import torch
from torch import nn
import dltools

定义注意力解码器

class AttentionDecoder(dltools.Decoder):def __init__(self, **kwargs):super().__init__(**kwargs)@propertydef attention_weights(self):raise NotImplementedError

添加Bahdanau的decoder

class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)self.attention = dltools.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs : (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# X : (batch_size, num_steps, vocab_size)X = self.embedding(X) # X : (batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens# print('query:', query.shape) # 4, 1, 16context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)# print('context: ', context.shape)x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)# print('x: ', x.shape)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)# print('out:', out.shape)# print('hidden_state:', hidden_state.shape)outputs.append(out)self._attention_weights.append(self.attention_weights)# print('---------------------------------')outputs = self.dense(torch.cat(outputs, dim=0))# print('解码器最终输出形状: ', outputs.shape)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weightsencoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()# batch_size 4, num_steps 7
X = torch.zeros((4, 7), dtype=torch.long)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
---------------------------------
解码器最终输出形状:  torch.Size([7, 4, 10])
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

训练

执行训练前,将decoder中的print屏蔽掉!!

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

评估指标 bleu

def bleu(pred_seq, label_seq, k):print('pred_seq', pred_seq)print('label_seq:', label_seq)pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')len_pred, len_label = len(pred_tokens), len(label_tokens)score = math.exp(min(0, 1 - (len_label / len_pred)))for n in range(1, k + 1):num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i + n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i + n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i + n])] -= 1score *=  math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))   return score

开始预测

engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est malade .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

http://www.ppmy.cn/ops/154304.html

相关文章

el-tree 父节点隐藏

这是我之前面试的一个题 让我写 如果你恰好也有这道题 希望可以帮到你 实现效果 <el-treenode-key"id"ref"tree"check-change"handleCheckChange":props"props":load"loadNode"lazyshow-checkbox //添加选择框>//深度…

深入MapReduce——计算模型设计

引入 通过引入篇&#xff0c;我们可以总结&#xff0c;MapReduce针对海量数据计算核心痛点的解法如下&#xff1a; 统一编程模型&#xff0c;降低用户使用门槛分而治之&#xff0c;利用了并行处理提高计算效率移动计算&#xff0c;减少硬件瓶颈的限制 优秀的设计&#xff0c…

上位机知识篇---Linux的shell脚本搜索、查找、管道

文章目录 前言第一部分&#xff1a;什么是shell&#xff1f;1. 基本结构脚本声明注释命令和表达式例子 2.变量控制结构条件判断 3.函数输入输出重定向 4.执行命令5.实际应用 第二部分&#xff1a;Linux的搜索、查找、管道命令1.搜索命令2.查找命令3.管道操作 总结 前言 以上就…

智能化加速标准和协议的更新并推动验证IP(VIP)在芯片设计中的更广泛应用

作者&#xff1a;Karthik Gopal, SmartDV Technologies亚洲区总经理 智权半导体科技&#xff08;厦门&#xff09;有限公司总经理 随着AI技术向边缘和端侧设备广泛渗透&#xff0c;芯片设计师不仅需要考虑在其设计中引入加速器&#xff0c;也在考虑采用速度更快和带宽更高的总…

Stable Diffusion 3.5 介绍

Stable Diffusion 3.5 是由 Stability AI 推出的最新一代图像生成模型&#xff0c;是 Stable Diffusion 系列的重要升级版本。以下是关于 Stable Diffusion 3.5 的详细信息&#xff1a; 版本概述 Stable Diffusion 3.5 包含三个主要版本&#xff1a; Stable Diffusion 3.5 L…

第23篇:Python开发进阶:详解测试驱动开发(TDD)

第23篇&#xff1a;测试驱动开发&#xff08;TDD&#xff09; 内容简介 在软件开发过程中&#xff0c;测试驱动开发&#xff08;TDD&#xff0c;Test-Driven Development&#xff09;是一种强调在编写实际代码之前先编写测试用例的开发方法。TDD不仅提高了代码的可靠性和可维…

【技术洞察】2024科技绘卷:浪潮、突破、未来

涌动与突破 2024年&#xff0c;科技的浪潮汹涌澎湃&#xff0c;人工智能、量子计算、脑机接口等前沿技术如同璀璨星辰&#xff0c;方便了大家的日常生活&#xff0c;也照亮了人类未来的道路。这一年&#xff0c;科技的突破与创新不断刷新着人们对未来的想象。那么回顾2024年的科…

新年快乐!!Market Moments 重磅更新!

2025 新年快乐&#xff01;在这个充满希望的新年里&#xff0c;愿大家都能心想事成&#xff0c;学业进步&#xff0c;健康快乐&#xff01; 随着新年的到来&#xff0c;Market Moments 也迎来重磅更新&#xff01; 首先&#xff0c;首页进行了改版升级&#xff0c;现在你可以在…