【llm对话系统】大模型源码分析之 LLaMA 模型的 Masked Attention

server/2025/2/4 4:38:41/

在大型语言模型(LLM)中,注意力机制(Attention Mechanism)是核心组成部分。然而,在自回归(autoregressive)模型中,例如 LLaMA,我们需要对注意力进行屏蔽(Masking),以防止模型“偷看”未来的信息。本文将深入探讨 LLaMA 模型中 Masked Attention 的实现逻辑,并对比其他类型大模型中常用的 Masked Attention 方案。

1. 什么是 Masked Attention

1.1 为什么需要 Mask

在自回归模型中,模型的目标是根据已有的输入序列预测下一个词。在训练阶段,模型会接收整个输入序列,但在预测某个位置的词时,模型不应该看到该位置之后的信息。这就是 Masked Attention 的作用:它会屏蔽未来词对当前词的影响,确保模型只能依赖于过去的信息进行预测。

1.2 Mask 的类型

Mask 主要分为两种类型:

  1. Padding Mask: 用于处理变长序列,屏蔽 padding 部分对注意力计算的影响。
  2. Causal Mask: 用于自回归模型,屏蔽未来位置的信息,防止模型偷看未来。

2. LLaMA 中的 Masked Attention

LLaMA 模型主要关注自回归的生成任务,所以使用的是 Causal Mask

2.1 LLaMA 的实现逻辑

LLaMA 使用标准的多头自注意力机制(Multi-Head Self-Attention, MHA),并在计算注意力权重时应用 Causal Mask。具体流程如下:

  1. 线性变换: 将输入序列映射为查询(Query)、键(Key)和值(Value)向量。
  2. 计算注意力分数: 计算查询向量和键向量的点积,并进行缩放。
  3. 应用 Mask: 使用 Causal Mask 屏蔽未来位置的注意力分数。
  4. 计算注意力权重: 对屏蔽后的注意力分数进行 Softmax 归一化。
  5. 计算加权值向量: 使用注意力权重对值向量进行加权求和。

2.2 LLaMA 源码示例 (PyTorch)

以下是 LLaMA 模型中 Masked Attention 的核心代码(简化版):

python">import torch
import torch.nn as nn
import torch.nn.functional as Fclass LlamaAttention(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.head_dim = d_model // num_heads# 线性变换self.Wq = nn.Linear(d_model, d_model)self.Wk = nn.Linear(d_model, d_model)self.Wv = nn.Linear(d_model, d_model)self.Wo = nn.Linear(d_model, d_model)def forward(self, x, mask=None):batch_size, seq_len, _ = x.size()# 线性变换q = self.Wq(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)k = self.Wk(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)v = self.Wv(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# 计算注意力分数scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)# 应用 Maskif mask is not None:scores = scores.masked_fill(mask==0, float('-inf'))# 计算注意力权重attn_weights = F.softmax(scores, dim=-1)# 计算加权值向量attn_output = torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)# 输出线性变换output = self.Wo(attn_output)return outputdef generate_causal_mask(seq_len):mask = torch.ones((seq_len, seq_len), dtype=torch.bool).triu(diagonal=1)return mask.bool()# 示例
import math
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2attention_layer = LlamaAttention(d_model, num_heads)
input_tensor = torch.randn(batch_size, seq_len, d_model)
causal_mask = generate_causal_mask(seq_len)output = attention_layer(input_tensor, mask=causal_mask)
print("Output Shape:", output.shape) # 输出: torch.Size([2, 10, 512])

代码解释:

  1. LlamaAttention 类:
    • 初始化线性变换层 Wq, Wk, Wv, 和 Wo
    • forward 方法中,首先对输入进行线性变换,并进行多头分割。
    • 计算注意力分数,将 qk 进行点积运算,并进行缩放。
    • 应用 Causal Mask。
    • 使用 Softmax 对分数进行归一化,得到注意力权重。
    • 使用注意力权重对 v 进行加权求和。
  2. generate_causal_mask 函数:
    • 生成一个下三角矩阵,其中对角线以上的位置为 True(即需要 Mask 的位置)。
    • 将 Mask 返回为布尔类型,方便后续使用 masked_fill 函数进行填充。
  3. 示例:
    • 使用随机的输入张量,构造 causal_mask。
    • 调用注意力层,得到输出。

2.3 Mask 应用细节

在代码中,我们使用了 scores.masked_fill(mask == 0, float('-inf')) 来应用 Mask。masked_fill 函数会将 mask 中为 False (也就是需要mask的位置) 的位置填充为 -inf。在 Softmax 计算时, -inf 将会被转换为 0,从而有效地屏蔽了未来的信息。

3. 与其他大模型 Masked Attention 方案的对比

3.1 GPT 系列模型

GPT 系列模型也使用 Causal Mask,其实现方式与 LLaMA 类似。主要区别在于:

  • 实现方式: GPT 系列模型通常使用 torch.triu() 函数来生成上三角 Mask,然后使用 masked_fill 函数填充。
  • 结构: GPT 模型主要使用单向 Transformer 结构,而 LLaMA 模型使用双向 Transformer 结构(encoder-decoder 结构)。

3.2 BERT 系列模型

BERT 系列模型主要用于理解任务,使用了双向注意力机制。BERT 使用两种 Mask:

  • Padding Mask: 用于处理变长序列,屏蔽 padding 部分的影响。
  • Attention Mask (随机mask): 在预训练阶段,随机 mask 输入序列中的一部分词,让模型预测被屏蔽的词。

3.3 对比总结

模型Mask 类型Mask 实现方式适用场景
LLaMACausal Maskmasked_fill自回归生成
GPT 系列Causal Masktorch.triu() + masked_fill自回归生成
BERT 系列Padding Mask & Attention Maskmasked_fill理解任务

4. 训练与推理时的 Mask

4.1 训练时

在训练阶段,我们会为每个输入序列都生成相应的 Causal Mask。Mask 的形状取决于输入序列的长度,确保模型只能看到当前位置之前的输入。

4.2 推理时

在推理阶段(生成文本时),我们需要动态更新 Mask。每生成一个新词,我们都会追加到当前序列,并为新的序列生成相应的 Causal Mask。LLaMA 模型为了提升推理效率,做了很多优化,比如KV Cache,增量式的更新mask,加速推理。

5. 总结

本文深入分析了 LLaMA 模型中 Masked Attention 的实现逻辑,并对比了其他类型大模型的 Masked Attention 方案。通过了解 Mask 的原理和具体实现,我们能更好地理解自回归模型的工作方式。希望本文能帮助你更好地理解大模型中的注意力机制!

6. 参考资料

  • Attention is All You Need
  • Transformer Language Models

http://www.ppmy.cn/server/164803.html

相关文章

跨境数据传输问题常见解决方式

在全球化经济的浪潮下,跨境数据传输已然成为企业日常运营的关键环节。随着数字贸易的蓬勃发展和跨国业务的持续扩张,企业在跨境数据处理方面遭遇了诸多棘手难题。那么,面对这些常见问题,企业该如何应对?镭速跨境数据传…

初级数据结构:栈和队列

目录 一、栈 (一)、栈的定义 (二)、栈的功能 (三)、栈的实现 1.栈的初始化 2.动态扩容 3.压栈操作 4.出栈操作 5.获取栈顶元素 6.获取栈顶元素的有效个数 7.检查栈是否为空 8.栈的销毁 9.完整代码 二、队列 (一)、队列的定义 (二)、队列的功能 (三&#xff09…

队列—学习

1. 手写队列的实现 使用数组实现队列是一种常见的方法。队列的基本操作包括入队(enqueue)和出队(dequeue)。队列的头部和尾部分别用 head 和 tail 指针表示。 代码实现 const int N 10000; // 定义队列容量,确保够…

为什么LabVIEW适合软硬件结合的项目?

LabVIEW是一种基于图形化编程的开发平台,广泛应用于软硬件结合的项目中。其强大的硬件接口支持、实时数据采集能力、并行处理能力和直观的用户界面,使得它成为工业控制、仪器仪表、自动化测试等领域中软硬件系统集成的理想选择。LabVIEW的设计哲学强调模…

K个不同子数组的数目--滑动窗口--字节--亚马逊

Stay hungry, stay foolish 题目描述 给定一个正整数数组 nums和一个整数 k,返回 nums 中 「好子数组」 的数目。 如果 nums 的某个子数组中不同整数的个数恰好为 k,则称 nums 的这个连续、不一定不同的子数组为 「好子数组 」。 例如,[1,2,…

使用PaddlePaddle实现逻辑回归:从训练到模型保存与加载

1. 引入必要的库 首先,需要引入必要的库。PaddlePaddle用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。 import paddle import pandas as pd import numpy as np import matplotlib.pyplot as plt 2. 加载自定…

【算法学习笔记】36:中国剩余定理(Chinese Remainder Theorem)求解线性同余方程组

中国剩余定理 假定存在 m 1 . . m k m_1..m_k m1​..mk​两两互质,中国剩余定理旨在求解这样的线性同余方程组中的 x x x: x ≡ a 1 ( m o d m 1 ) x ≡ a 2 ( m o d m 2 ) . . . x ≡ a k ( m o d m k ) x \equiv a_1~(mod~m_1) \\ x \equiv a_2~(mod…

【面经】字节南京一面部分题目记录

南京字节一面题,可能因为项目不太匹配,全程八股比较多,也有两道手撕代码题,强度还是有的。为了方便大家学习,大部分答案由GPT整理,有些题给出了我认为回答比较好的博客链接。 文章目录 一、python2 和 pyth…