BERT模型中的多头注意力机制详解

embedded/2024/11/20 0:17:26/
摘要

深度学习领域,特别是自然语言处理(NLP)中,Transformer模型因其卓越的性能而广受关注。其中,多头注意力机制是Transformer模型的核心组成部分之一。本文将深入探讨BERT模型中多头注意力机制的实现细节,帮助读者更好地理解和应用这一关键技术。

1. 引言

BERT(Bidirectional Encoder Representations from Transformers)是一种基于Transformer架构的预训练模型。Transformer模型的核心在于其多头注意力机制,该机制允许模型在处理序列数据时同时关注多个位置的信息,从而提高了模型的表达能力和泛化能力。本文将详细介绍BERT模型中多头注意力机制的实现。

2. 多头注意力机制概述

多头注意力机制的基本思想是将输入张量投影到多个不同的子空间中,在每个子空间中独立计算注意力权重,然后将这些子空间的结果合并起来。这种机制使得模型能够在不同的抽象层次上捕获信息,从而提高了模型的性能。

3. 函数定义
def attention_layer(from_tensor,to_tensor,attention_mask=None,num_attention_heads=1,size_per_head=512,query_act=None,key_act=None,value_act=None,attention_probs_dropout_prob=0.0,initializer_range=0.02,do_return_2d_tensor=False,batch_size=None,from_seq_length=None,to_seq_length=None):"""Performs multi-headed attention from `from_tensor` to `to_tensor`.Args:from_tensor: float Tensor of shape [batch_size, from_seq_length, from_width].to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].attention_mask: (optional) int32 Tensor of shape [batch_size, from_seq_length, to_seq_length].num_attention_heads: int. Number of attention heads.size_per_head: int. Size of each attention head.query_act: (optional) Activation function for the query transform.key_act: (optional) Activation function for the key transform.value_act: (optional) Activation function for the value transform.attention_probs_dropout_prob: (optional) float. Dropout probability of the attention probabilities.initializer_range: float. Range of the weight initializer.do_return_2d_tensor: bool. If True, the output will be of shape [batch_size * from_seq_length, num_attention_heads * size_per_head].batch_size: (Optional) int. If the input is 2D, this might be the batch size of the 3D version of the `from_tensor` and `to_tensor`.from_seq_length: (Optional) If the input is 2D, this might be the seq length of the 3D version of the `from_tensor`.to_seq_length: (Optional) If the input is 2D, this might be the seq length of the 3D version of the `to_tensor`.Returns:float Tensor of shape [batch_size, from_seq_length, num_attention_heads * size_per_head]."""...
4. 实现细节
4.1 输入张量形状检查

函数首先检查输入张量 from_tensorto_tensor 的形状是否符合预期。这两个张量的形状应该是 [batch_size, seq_length, hidden_size]

4.2 投影到查询、键和值张量
  1. 查询张量:将 from_tensor 投影到查询张量 query_layer
  2. 键张量:将 to_tensor 投影到键张量 key_layer
  3. 值张量:将 to_tensor 投影到值张量 value_layer
query_layer = tf.layers.dense(from_tensor_2d,num_attention_heads * size_per_head,activation=query_act,name="query",kernel_initializer=create_initializer(initializer_range))key_layer = tf.layers.dense(to_tensor_2d,num_attention_heads * size_per_head,activation=key_act,name="key",kernel_initializer=create_initializer(initializer_range))value_layer = tf.layers.dense(to_tensor_2d,num_attention_heads * size_per_head,activation=value_act,name="value",kernel_initializer=create_initializer(initializer_range))
4.3 转置张量以适应多头注意力

为了适应多头注意力机制,需要将查询、键和值张量转置为 [batch_size, num_attention_heads, seq_length, size_per_head] 的形状。

query_layer = transpose_for_scores(query_layer, batch_size, num_attention_heads, from_seq_length, size_per_head)
key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, to_seq_length, size_per_head)
4.4 计算注意力分数

通过矩阵乘法计算查询张量和键张量之间的点积,得到原始的注意力分数。然后,将这些分数除以 sqrt(size_per_head) 进行缩放。

attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = tf.multiply(attention_scores, 1.0 / math.sqrt(float(size_per_head)))
4.5 应用注意力掩码

如果提供了 attention_mask,则将其扩展为 [batch_size, 1, from_seq_length, to_seq_length] 的形状,并将其应用于注意力分数,以屏蔽不需要关注的位置。

if attention_mask is not None:attention_mask = tf.expand_dims(attention_mask, axis=[1])adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0attention_scores += adder
4.6 归一化注意力分数

使用 softmax 函数将注意力分数归一化为概率分布。

attention_probs = tf.nn.softmax(attention_scores)
4.7 应用dropout

为了防止过拟合,可以在注意力概率上应用dropout。

attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
4.8 计算上下文向量

通过将注意力概率与值张量相乘,得到上下文向量。

context_layer = tf.matmul(attention_probs, value_layer)
4.9 转置并重塑上下文向量

最后,将上下文向量转置并重塑为 [batch_size, from_seq_length, num_attention_heads * size_per_head] 的形状。

context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
if do_return_2d_tensor:context_layer = tf.reshape(context_layer, [batch_size * from_seq_length, num_attention_heads * size_per_head])
else:context_layer = tf.reshape(context_layer, [batch_size, from_seq_length, num_attention_heads * size_per_head])
5. 应用示例

假设我们有一个输入张量 from_tensor 和一个目标张量 to_tensor,以及一个注意力掩码 attention_mask,我们可以使用上述函数进行多头注意力计算:

import tensorflow as tf# 假设的输入张量和掩码
from_tensor = tf.random.uniform([2, 10, 128])
to_tensor = tf.random.uniform([2, 10, 128])
attention_mask = tf.constant([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 0, 0, 0, 0, 0]], dtype=tf.int32)# 多头注意力计算
context_layer = attention_layer(from_tensor=from_tensor,to_tensor=to_tensor,attention_mask=attention_mask,num_attention_heads=8,size_per_head=16,attention_probs_dropout_prob=0.1
)with tf.Session() as sess:sess.run(tf.global_variables_initializer())context_layer_val = sess.run(context_layer)print("Context Layer Shape:", context_layer_val.shape)
6. 结论

本文详细介绍了BERT模型中的多头注意力机制的实现。通过这一机制,模型能够在不同的抽象层次上捕获信息,从而提高了模型的表达能力和泛化能力。希望本文能为读者在自然语言处理领域的研究和开发提供有益的参考。


http://www.ppmy.cn/embedded/138910.html

相关文章

Redis 高并发缓存架构实战与性能优化

前言 在高并发场景下,同时操作数据库和缓存会存在数据不一致性的问题。这常常在面试时,面试官很喜欢问的一个问题,你们系统有用 Redis?使用Redis实现了哪些业务场景?如何保证数据的一致性? 问题 总体归纳…

正态分布密度函数的基本概念

概率论中的正态分布密度函数是统计学和数据分析中的一个核心概念,而MATLAB作为一种强大的数学计算软件,为处理和分析正态分布数据提供了丰富的工具和函数。以下是对正态分布密度函数及其在MATLAB中的应用的详细探讨。 一、正态分布密度函数的基本概念 …

【苍穹外卖】学习日志-day1

目录 nginx 反向代理介绍 nginx 的优势 提高访问速度 负载均衡 保证后端服务安全 高并发静态资源 Swagger 生成 API 文档 Swagger 的使用方式 导入knife4j的maven坐标 在配置类中加入knife4j相关配置 设置静态资源映射 通过注解控制生成的接口文档 项目技术点 Token 模式 MD5 加…

docker 部署freeswitch(非编译方式)

一:安装部署 1.拉取镜像 参考:https://hub.docker.com/r/safarov/freeswitch docker pull safarov/freeswitch 2.启动镜像 docker run --nethost --name freeswitch \-e SOUND_RATES8000:16000 \-e SOUND_TYPESmusic:en-us-callie \-v /home/xx/f…

51单片机基础01 单片机最小系统

目录 一、什么是51单片机 二、51单片机的引脚介绍 1、VCC GND 2、XTAL1 2 3、RST 4、EA 5、PSEN 6、ALE 7、RXD、TXD 8、INT0、INT1 9、T0、T1 10、MOSI、MISO、SCK 11、WR、RD 12、通用IO P0 13、通用IO P1 14、通用IO P2 三、51单片机的最小系统 1、供电与…

小林Coding—Java「五、Java虚拟机面试篇」

五、Java虚拟机面试篇(难⭐️⭐️⭐️) 内存模型 JVM的内存模型介绍一下 JVM运行时内存共分为虚拟机栈、堆、元空间、程序计数器、本地方法栈 五个部分。还有一部分内存叫直接内存,属于操作系统的本地内存,也是可以直接操作的。 …

thinkphp6 入门(2)--视图、渲染html页面、赋值

use think\facade\View;View::assign([name > ThinkPHP,email > thinkphpqq.com]);View::assign(data,[name > ThinkPHP,email > thinkphpqq.com]); View::fetch(index);助手函数 view(index, [name > ThinkPHP,email > thinkphpqq.com ]); 模板输出 {$na…

集群聊天服务器(12)nginx负载均衡器

目录 负载均衡器nginx负载均衡器优势 如何解决集群聊天服务器跨服务器通信问题?nginx的TCP负载均衡配置nginx配置 负载均衡器 目前最多只能支持2w台客户机进行同时聊天 所以要引入集群,多服务器。 但是客户连哪一台服务器呢?客户并不知道哪一…