【自然语言处理(NLP)】多头注意力(Multi - Head Attention)原理及代码实现

news/2025/2/3 16:06:40/

文章目录

  • 介绍
  • 多头注意力
    • 原理
    • 代码实现
      • 导包
      • 多头注意力结构
      • qkv转换
      • output转换
      • 构建注意力模块
      • 添加Bahdanau的decoder
      • 训练
      • 预测

个人主页:道友老李
欢迎加入社区:道友老李的学习社区

介绍

**自然语言处理(Natural Language Processing,NLP)**是计算机科学领域与人工智能领域中的一个重要方向。它研究的是人类(自然)语言与计算机之间的交互。NLP的目标是让计算机能够理解、解析、生成人类语言,并且能够以有意义的方式回应和操作这些信息。

NLP的任务可以分为多个层次,包括但不限于:

  1. 词法分析:将文本分解成单词或标记(token),并识别它们的词性(如名词、动词等)。
  2. 句法分析:分析句子结构,理解句子中词语的关系,比如主语、谓语、宾语等。
  3. 语义分析:试图理解句子的实际含义,超越字面意义,捕捉隐含的信息。
  4. 语用分析:考虑上下文和对话背景,理解话语在特定情境下的使用目的。
  5. 情感分析:检测文本中表达的情感倾向,例如正面、负面或中立。
  6. 机器翻译:将一种自然语言转换为另一种自然语言。
  7. 问答系统:构建可以回答用户问题的系统。
  8. 文本摘要:从大量文本中提取关键信息,生成简短的摘要。
  9. 命名实体识别(NER):识别文本中提到的特定实体,如人名、地名、组织名等。
  10. 语音识别:将人类的语音转换为计算机可读的文字格式。

NLP技术的发展依赖于算法的进步、计算能力的提升以及大规模标注数据集的可用性。近年来,深度学习方法,特别是基于神经网络的语言模型,如BERT、GPT系列等,在许多NLP任务上取得了显著的成功。随着技术的进步,NLP正在被应用到越来越多的领域,包括客户服务、智能搜索、内容推荐、医疗健康等。

多头注意力

原理

多头注意力机制(Multi - Head Attention)的结构示意图:
在这里插入图片描述

多头注意力机制首先将查询(Query)、键(Key)、值(Value)分别通过多个全连接层进行线性变换,得到多个不同的表示。然后,对这些不同的表示分别进行注意力计算。最后,将各个注意力的结果进行连结(Concatenate),再通过一个全连接层得到最终输出。

这种机制允许模型在不同的表示子空间中并行地关注输入序列的不同部分,能够捕捉到更丰富的语义信息,广泛应用于Transformer等模型架构中,在自然语言处理、计算机视觉等领域有重要应用。

模型计算方式:
在这里插入图片描述
在该表达式中, h i h_i hi 是注意力机制计算得到的输出, f f f 一般表示注意力计算函数(如缩放点积注意力等), W q i W_q^i Wqi W k i W_k^i Wki W v i W_v^i Wvi 分别是针对查询(query)、键(key)、值(value)的可学习权重矩阵, q q q k k k v v v 分别为查询向量、键向量、值向量 , R n \mathbb{R}^n Rn 表示输出 h i h_i hi 处于 n n n 维实数空间。它表达了在注意力计算中,通过对查询、键、值进行线性变换后再经过注意力计算函数得到输出的过程。

矩阵运算表达式:
在这里插入图片描述

表达式中 [ h 1 ⋮ h n ] \begin{bmatrix}h_1\\ \vdots \\h_n\end{bmatrix} h1hn 是一个由 h 1 h_1 h1 h n h_n hn 构成的列向量,这些 h i h_i hi 通常可以是注意力机制等模块的输出。 W o W_o Wo 是一个可学习的权重矩阵,其维度为 R p × n \mathbb{R}^{p\times n} Rp×n ,这里 p p p 是输出维度相关参数, n n n 是输入向量的长度(即 h i h_i hi 的数量)。该表达式表示对由 h i h_i hi 组成的向量进行线性变换,常用于深度学习模型(如Transformer等)的后处理阶段,对前面模块输出进行进一步的特征变换或整合。

代码实现

导包

import math
import torch
from torch import nn
import dltools

多头注意力结构

class MultiHeadAttention(nn.Module):def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):super().__init__(**kwargs)self.num_heads = num_headsself.attention = dltools.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries, keys, values 传入的形状: (batch_size, 查询熟练或者键值对数量, num_hiddens)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)
#         print('queries:', queries.shape)
#         print('keys:', keys.shape)
#         print('values:', values.shape)if valid_lens is not None:valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output shape: (batch_size * num_heads, 查询的个数, num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)
#         print('output:', output.shape)output_concat = transpose_output(output, self.num_heads)
#         print('output_concat:', output_concat.shape)return self.W_o(output_concat)

qkv转换

def transpose_qkv(X, num_heads):# 输入X的shape: (batch_size, 查询数/键值对数, num_hiddens)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)X = X.permute(0, 2, 1, 3) # batch_size, num_heads, 查询数/ 键值对数, num_hiddens/num_heads# 这里是把batch_size和num_heads合并在一起了. return X.reshape(-1, X.shape[2], X.shape[3]) # batch_size * num_heads, 查询/键值对数, num_hiddens/ num_heads

output转换

def transpose_output(X, num_heads):# 逆转transpose_qkv的操作X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)

构建注意力模块

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.2)
attention.eval()

在这里插入图片描述

添加Bahdanau的decoder

class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_heads, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs : (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# X : (batch_size, num_steps, vocab_size)X = self.embedding(X) # X : (batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens
#             print('query:', query.shape) # 4, 1, 16context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context: ', context.shape)x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x: ', x.shape)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out:', out.shape)
#             print('hidden_state:', hidden_state.shape)outputs.append(out)self._attention_weights.append(self.attention_weights)#         print('---------------------------------')outputs = self.dense(torch.cat(outputs, dim=0))
#         print('解码器最终输出形状: ', outputs.shape)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights

训练

embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

预测

engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est paresseux .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

http://www.ppmy.cn/news/1568992.html

相关文章

计算机网络 性能指标相关

目录 吞吐量 时延 时延带宽积 往返时延RTT 利用率 吞吐量 时延 时延带宽积 往返时延RTT 利用率

使用QSqlQueryModel创建交替背景色的表格模型

class UserModel(QSqlQueryModel):def __init__(self):super().__init__()self._query "SELECT name, age FROM users"self.refresh()def refresh(self):self.setQuery(self._query)# 重新定义data()方法def data(self, index, role): if role Qt.BackgroundRole…

SpringCloudGateWay和Sentinel结合做黑白名单来源控制

假设我们的分布式项目,admin是8087,gateway是8088,consumer是8086 我们一般的思路是我们的请求必须经过我们的网关8088然后网关转发到我们的分布式项目,那我要是没有处理我们绕过网关直接访问项目8087和8086不也是可以&#xff1…

Selenium 浏览器操作与使用技巧——详细解析(Java版)

目录 一、浏览器及窗口操作 二、键盘与鼠标操作 三、勾选复选框 四、多层框架/窗口定位 五、操作下拉框 六、上传文件操作 七、处理弹窗与 alert 八、处理动态元素 九、使用 Selenium 进行网站监控 前言 Selenium 是一款非常强大的 Web 自动化测试工具,能够…

AAPM:基于大型语言模型代理的资产定价模型,夏普比率提高9.6%

“AAPM: Large Language Model Agent-based Asset Pricing Models” 论文地址:https://arxiv.org/pdf/2409.17266v1 Github地址:https://github.com/chengjunyan1/AAPM 摘要 这篇文章介绍了一种利用LLM代理的资产定价模型(AAPM)…

【C语言】填空题/程序填空题1

1. 下列程序取出一个整数x的二进制表示中,从第p位开始的n位二进制,并输出所表示的整数值。如: 输入:-17 5 3 输出:5 【说明】整数-17的32位二进制表示为:11111111 11111111 11111111 11101111,…

wx043基于springboot+vue+uniapp的智慧物流小程序

开发语言:Java框架:springbootuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包&#…

康德哲学与自组织思想的渊源:从《判断力批判》到系统论的桥梁

康德哲学与自组织思想的渊源:从《判断力批判》到系统论的桥梁 第一节:康德哲学中的自然目的论与自组织思想 核心内容: 康德哲学中的自然目的论和反思判断力概念,为现代系统论中的自组织思想提供了哲学基础,预见了复…