Transformer Implementation and PyTorch Source Code Walkthrough (4): Multi-Head Attention (MultiheadAttention)


Introduction

Following the previous three posts in this Transformer walkthrough series, this post supplements the third one by explaining how the MultiheadAttention class processes data in the source code.

Source files involved

\site-packages\torch\nn\modules\activation.py
\site-packages\torch\nn\functional.py

Classes and functions involved

Class used from \site-packages\torch\nn\modules\activation.py:
MultiheadAttention
Functions used from \site-packages\torch\nn\functional.py:
_in_projection_packed
_scaled_dot_product_attention
multi_head_attention_forward

Data flow

Step 1:
The MultiheadAttention class is initialized with the parameters below. The main arguments are the embedding dimension of the word vectors and the number of attention heads.

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # binary flag: do q, k and v share the same embedding dimension?
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        # print("=========================_qkv_same_embed_dim=:", self._qkv_same_embed_dim)
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            # print(self.in_proj_weight.shape)
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

Here in_proj_weight and in_proj_bias are the projection weight and bias to be initialized. The flag _qkv_same_embed_dim records whether query, key and value share the same embedding dimension; when they do (the usual self-attention case), a single packed weight of shape (3 * embed_dim, embed_dim) is allocated instead of three separate q/k/v projection matrices.
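
As a quick sanity check (my own illustration, not from the original post; the sizes 512 and 8 are assumed), instantiating the module shows the packed parameters described above:

import torch
import torch.nn as nn

# illustrative sizes: embed_dim=512, num_heads=8 (head_dim = 512 // 8 = 64)
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)

# kdim/vdim were not given, so q, k and v share the embedding dimension and
# a single packed weight of shape (3 * embed_dim, embed_dim) is allocated
print(mha.in_proj_weight.shape)   # torch.Size([1536, 512])
print(mha.in_proj_bias.shape)     # torch.Size([1536])
print(mha.out_proj.weight.shape)  # torch.Size([512, 512])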
Step 2
Initializing the weights and bias:
Initialization is done in the _reset_parameters() function, which fills every entry of the projection weights with a value drawn uniformly from (-a, a) (the packed bias, if present, is zero-initialized). The bound a is computed with the following formula:
a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
This formula is implemented as follows:

def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a)

fan_in and fan_out are determined from the dimensions of the input tensor:

def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        # math.prod is not always available, accumulate the product manually
        # we could use functools.reduce but that is not supported by TorchScript
        for s in tensor.shape[2:]:
            receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out
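
To put concrete numbers on this (my own illustration, with embed_dim = 512 assumed): for the packed in_proj_weight of shape (3 * 512, 512) we get fan_in = 512 and fan_out = 1536, so with the default gain of 1 the bound is a = sqrt(6 / 2048) ≈ 0.054.

import math
import torch
import torch.nn as nn

w = torch.empty(3 * 512, 512)                    # same shape as in_proj_weight
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(w)
print(fan_in, fan_out)                           # 512 1536

a = 1.0 * math.sqrt(6.0 / (fan_in + fan_out))    # gain = 1 by default
print(round(a, 4))                               # 0.0541

nn.init.xavier_uniform_(w)                       # fills w uniformly in (-a, a)
print(w.abs().max().item() <= a)                 # True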

Step 3
The computation for each batch is carried out in the forward method of the MultiheadAttention class, which calls multi_head_attention_forward.
In simplified terms, the following operations are performed:
(1) query, key and value each pass through a fully connected (linear) projection

if not use_separate_proj_weight:
    # multiply the three inputs by in_proj_weight
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
          same shape as the corresponding input tensor.
    """
    E = q.size(-1)
    if k is v:
        if q is k:
            # print("=========q:", q.shape)
            # print("=========w:", w.shape)
            return linear(q, w, b).chunk(3, dim=-1)
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
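
A small sketch of what the self-attention branch above does (my own illustration; the sizes and random weights are assumptions, not taken from the post): one matmul against the packed (3E, E) weight, followed by a chunk into q', k' and v'. This is equivalent to projecting with the corresponding slices of the packed weight.

import torch
import torch.nn.functional as F

E, L, N = 512, 10, 2                        # embed_dim, seq_len, batch (illustrative)
x = torch.randn(L, N, E)                    # self-attention: query = key = value = x
w = torch.randn(3 * E, E)                   # stands in for the packed in_proj_weight
b = torch.randn(3 * E)                      # stands in for the packed in_proj_bias

# single projection, then split into three chunks (the `q is k` branch)
q, k, v = F.linear(x, w, b).chunk(3, dim=-1)
print(q.shape, k.shape, v.shape)            # each torch.Size([10, 2, 512])

# same result as projecting with the q-slice of the packed weight
q_ref = F.linear(x, w[:E], b[:E])
print(torch.allclose(q, q_ref, atol=1e-4))  # True (up to floating-point tolerance)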

(2) The three tensors are reshaped along the batch dimension according to the number of heads.
This is where the multi-head mechanism does its work.

q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
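
To make the reshape concrete (my own example with assumed sizes tgt_len = 10, bsz = 2, num_heads = 8, head_dim = 64, i.e. embed_dim = 512): the heads are folded into the batch dimension, so all bsz * num_heads slices can be attended over independently by a single batched matmul.

import torch

tgt_len, bsz, num_heads, head_dim = 10, 2, 8, 64   # embed_dim = 8 * 64 = 512
q = torch.randn(tgt_len, bsz, num_heads * head_dim)

# (tgt_len, bsz, embed_dim) -> (bsz * num_heads, tgt_len, head_dim)
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
print(q.shape)  # torch.Size([16, 10, 64])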

(3) The three tensors are multiplied together in sequence (scaled dot-product attention),
which yields the attention weights and the final result. Note the use of torch.bmm here.


def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r"""
    Computes scaled dot product attention on query, key and value tensors, using
    an optional attention mask if passed, and applying dropout if a probability
    greater than 0.0 is specified.
    Returns a tensor pair containing attended values and attention weights.

    Args:
        q, k, v: query, key and value tensors. See Shape section for shape details.
        attn_mask: optional tensor containing mask values to be added to calculated
            attention. May be 2D or 3D; see Shape section for details.
        dropout_p: dropout probability. If greater than 0.0, dropout is applied.

    Shape:
        - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
          and E is embedding dimension.
        - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
          and E is embedding dimension.
        - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
          and E is embedding dimension.
        - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
          shape :math:`(Nt, Ns)`.

        - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
          have shape :math:`(B, Nt, Ns)`
    """
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn += attn_mask
    attn = softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn
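
A minimal sketch of the two bmm calls (my own example; B = bsz * num_heads = 16, Nt = 10, Ns = 12 and E = head_dim = 64 are assumed sizes):

import math
import torch

B, Nt, Ns, E = 16, 10, 12, 64              # bsz*num_heads, target len, source len, head_dim
q = torch.randn(B, Nt, E) / math.sqrt(E)   # scale the query, as in the source above
k = torch.randn(B, Ns, E)
v = torch.randn(B, Ns, E)

attn = torch.softmax(torch.bmm(q, k.transpose(-2, -1)), dim=-1)  # (B, Nt, Ns)
out = torch.bmm(attn, v)                                         # (B, Nt, E)
print(attn.shape, out.shape)  # torch.Size([16, 10, 12]) torch.Size([16, 10, 64])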

Summary

The source code's handling of num_heads contains some redundancy.

