【llm对话系统】大模型源码分析之llama模型的long context更长上下文支持

news/2025/2/7 0:05:19/

1. 引言

Llama模型的一个重要特性是支持长上下文处理。本文将深入分析Llama源码中实现长上下文的关键技术点,包括位置编码(position embedding)的外推方法、注意力机制的优化等。我们将通过详细的代码解析来理解其实现原理。

2. 位置编码的外推实现

2.1 旋转位置编码(RoPE)基础

Llama采用旋转位置编码(RoPE, Rotary Position Embedding)来编码token的位置信息。RoPE的实现包含几个关键步骤:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, scale: float = 1.0):"""预计算RoPE的频率Args:dim: 隐藏层维度end: 序列最大长度theta: RoPE的基频参数scale: 位置缩放因子Returns:freqs_cis: 复数形式的频率矩阵"""# 生成维度序列 [0, 2, ..., dim-2]dims = torch.arange(0, dim, 2)[: (dim // 2)].float()# 计算频率基数 1/θ^(2i/d)freqs = 1.0 / (theta ** (dims / dim))# 生成位置序列并应用缩放t = torch.arange(end, device=freqs.device) * scale# 计算位置和频率的外积freqs = torch.outer(t, freqs)# 转换为复数形式 e^(iθ)freqs_cis = torch.polar(torch.ones_like(freqs), freqs)return freqs_cisdef apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:"""应用旋转位置编码Args:xq: query张量 [batch_size, seq_len, num_heads, head_dim]xk: key张量 [batch_size, seq_len, num_heads, head_dim]freqs_cis: 预计算的频率 [seq_len, head_dim//2]"""# 重塑张量以方便运算xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)# 提取频率的实部和虚部freqs_cos = freqs_cis.real()freqs_sin = freqs_cis.imag()# 应用旋转变换# xq_out = xq * cos(θ) + rotate_half(xq) * sin(θ)xq_out_r = xq_r * freqs_cos - xq_i * freqs_sinxq_out_i = xq_r * freqs_sin + xq_i * freqs_cosxk_out_r = xk_r * freqs_cos - xk_i * freqs_sinxk_out_i = xk_r * freqs_sin + xk_i * freqs_cos# 重新组合实部和虚部xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)return xq_out.type_as(xq), xk_out.type_as(xk)

2.2 动态NTK外推方案

动态NTK缩放是实现长上下文的关键技术,它通过动态调整位置编码的缩放因子来改善模型在更长序列上的表现:

class LlamaConfig:def __init__(self):self.rope_scaling = {"type": "dynamic",  # 动态缩放类型"factor": 2.0,      # 基础缩放因子"original_max_position_embeddings": 2048  # 原始训练长度}def compute_dynamic_ntk_scaling(ctx_len: int,orig_ctx_len: int = 2048,base_scale: float = 0.25,alpha: float = 1.0
) -> float:"""计算动态NTK缩放因子Args:ctx_len: 当前上下文长度orig_ctx_len: 原始训练上下文长度base_scale: 基础缩放系数alpha: 缩放曲线的陡峭程度"""# 使用对数曲线计算缩放因子return base_scale * math.log(ctx_len / orig_ctx_len) ** alphaclass LlamaAttention(nn.Module):def __init__(self, config: LlamaConfig):super().__init__()self.config = configself.rope_scaling = config.rope_scalingdef forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,) -> torch.Tensor:"""注意力前向计算Args:hidden_states: 输入张量 [batch_size, seq_len, hidden_size]attention_mask: 注意力掩码position_ids: 位置索引"""seq_len = hidden_states.shape[1]# 计算动态缩放因子if self.rope_scaling["type"] == "dynamic":rope_scale = compute_dynamic_ntk_scaling(seq_len,self.config.rope_scaling["original_max_position_embeddings"],base_scale=self.rope_scaling["factor"])else:rope_scale = 1.0# 计算位置编码freqs_cis = precompute_freqs_cis(self.head_dim,seq_len,scale=rope_scale)# 应用旋转位置编码query_states, key_states = apply_rotary_emb(self.q_proj(hidden_states),self.k_proj(hidden_states),freqs_cis)

3. 注意力机制优化

3.1 分块注意力计算

为了高效处理长序列,Llama实现了分块注意力计算。以下是详细的实现代码:

class ChunkedAttention(nn.Module):def __init__(self, chunk_size: int = 1024):super().__init__()self.chunk_size = chunk_sizedef forward(self,query: torch.Tensor,      # [batch, num_heads, seq_len, head_dim]key: torch.Tensor,        # [batch, num_heads, seq_len, head_dim]value: torch.Tensor,      # [batch, num_heads, seq_len, head_dim]mask: Optional[torch.Tensor] = None) -> torch.Tensor:"""分块计算注意力"""batch_size, num_heads, seq_len, head_dim = query.shape# 计算需要的块数num_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size# 存储每个块的输出chunked_outputs = []# 按块计算注意力for chunk_idx in range(num_chunks):# 计算当前块的起止位置chunk_start = chunk_idx * self.chunk_sizechunk_end = min(chunk_start + self.chunk_size, seq_len)# 提取当前块的querychunk_query = query[:, :, chunk_start:chunk_end]# 计算注意力得分chunk_scores = torch.matmul(chunk_query,                    # [b, h, chunk_size, d]key.transpose(-2, -1)           # [b, h, d, seq_len])   # 得到 [b, h, chunk_size, seq_len]# 缩放注意力得分chunk_scores = chunk_scores / math.sqrt(head_dim)# 应用attention maskif mask is not None:chunk_mask = mask[:, :, chunk_start:chunk_end, :]chunk_scores = chunk_scores + chunk_mask# 应用softmaxchunk_attn = F.softmax(chunk_scores, dim=-1)# 计算输出chunk_output = torch.matmul(chunk_attn, value)chunked_outputs.append(chunk_output)# 拼接所有块的输出return torch.cat(chunked_outputs, dim=2)

3.2 优化的KV Cache实现

KV Cache的实现需要考虑内存效率和计算性能:

class KVCache:def __init__(self,max_batch_size: int,max_seq_length: int,num_heads: int,head_dim: int,dtype: torch.dtype = torch.float16):"""初始化KV缓存Args:max_batch_size: 最大批次大小max_seq_length: 最大序列长度num_heads: 注意力头数head_dim: 每个头的维度dtype: 数据类型"""self.max_seq_length = max_seq_length# 初始化缓存张量self.k_cache = torch.zeros(max_batch_size,num_heads,max_seq_length,head_dim,dtype=dtype)self.v_cache = torch.zeros(max_batch_size,num_heads,max_seq_length,head_dim,dtype=dtype)# 记录当前序列长度self.current_length = 0def update(self,key: torch.Tensor,value: torch.Tensor,position: int) -> None:"""更新缓存Args:key: key状态 [batch_size, num_heads, seq_len, head_dim]value: value状态 [batch_size, num_heads, seq_len, head_dim]position: 起始位置"""seq_len = key.shape[2]if position + seq_len > self.max_seq_length:raise ValueError(f"Position {position + seq_len} exceeds max_seq_length {self.max_seq_length}")# 更新缓存self.k_cache[:, :, position:position+seq_len] = keyself.v_cache[:, :, position:position+seq_len] = value# 更新当前长度self.current_length = max(self.current_length, position + seq_len)def get_cached_kv(self,start_pos: int,end_pos: int) -> Tuple[torch.Tensor, torch.Tensor]:"""获取指定范围的缓存内容"""return (self.k_cache[:, :, start_pos:end_pos],self.v_cache[:, :, start_pos:end_pos])def clear(self) -> None:"""清空缓存"""self.k_cache.zero_()self.v_cache.zero_()self.current_length = 0

4. 实际应用示例

让我们看一个完整的使用示例,展示如何处理长文本:

class LongContextProcessor:def __init__(self,model: LlamaModel,tokenizer,max_length: int = 16384,chunk_size: int = 1024):self.model = modelself.tokenizer = tokenizerself.chunk_size = chunk_size# 初始化KV缓存self.kv_cache = KVCache(max_batch_size=1,max_seq_length=max_length,num_heads=model.config.num_attention_heads,head_dim=model.config.hidden_size // model.config.num_attention_heads)def process_long_text(self, text: str) -> torch.Tensor:"""处理长文本输入Args:text: 输入文本Returns:处理后的隐藏状态"""# 分词tokens = self.tokenizer(text,return_tensors="pt",truncation=False).input_ids# 清空KV缓存self.kv_cache.clear()# 分块处理all_hidden_states = []for i in range(0, tokens.size(1), self.chunk_size):# 获取当前块chunk = tokens[:, i:i+self.chunk_size]# 获取位置编码索引position_ids = torch.arange(i,i + chunk.size(1),dtype=torch.long,device=chunk.device).unsqueeze(0)# 获取当前位置的缓存k_cache, v_cache = self.kv_cache.get_cached_kv(0, i)# 前向计算outputs = self.model(chunk,position_ids=position_ids,past_key_values=[(k_cache, v_cache)] * self.model.config.num_hidden_layers)# 更新缓存self.kv

http://www.ppmy.cn/news/1569936.html

相关文章

【力扣】53.最大子数组和

AC截图 题目 思路 这道题主要考虑的就是要排除负数带来的负面影响。如果遍历数组,那么应该有如下关系式: currentAns max(prenums[i],nums[i]) pre是之前记录的最大和,如果prenums[i]小于nums[i],就要考虑舍弃pre,从…

Swift 进阶:Observation 框架中可观察(@Observable)对象的高级操作(上)

概述 在 WWDC 24 中苹果推出了全新的 Observation 框架,借助于它我们可以更加细粒度的监听可观察(@Observable)对象 。同时,SwiftUI 自身也与时偕行开始全面支持 @Observable 对象的“嵌入”。 然而在这里,我们却另辟蹊径来介绍 @Observable 对象另外一些“鲜为人知”的故…

【C++】STL——vector底层实现

目录 💕 1.vector三个核心 💕2.begin函数,end函数的实现(简单略讲) 💕3.size函数,capacity函数的实现 (简单略讲) 💕4.reserve函数实现 (细节…

机器学习7-全连接神经网络3-过拟合与超参数

机器学习7-全连接神经网络3-过拟合欠拟合 过拟合应对过拟合-最优方案:获取更多的训练数据应对过拟合-次优方案:正则化应对过拟合-次优方案2:随机失活综合考量 超参数超参数优化方法 过拟合 机器学习的根本问题是优化和泛化的问题。优化——是…

RabbitMQ深度探索:简单实现 MQ

基于多线程队列实现 MQ &#xff1a; 实现类&#xff1a; public class ThreadMQ {private static LinkedBlockingDeque<JSONObject> broker new LinkedBlockingDeque<JSONObject>();public static void main(String[] args) {//创建生产者线程Thread producer n…

Python从0到100(八十七):CNN网络详细介绍及WISDM数据集模型仿真

前言&#xff1a; 零基础学Python&#xff1a;Python从0到100最新最全教程。 想做这件事情很久了&#xff0c;这次我更新了自己所写过的所有博客&#xff0c;汇集成了Python从0到100&#xff0c;共一百节课&#xff0c;帮助大家一个月时间里从零基础到学习Python基础语法、Pyth…

Chromium132 编译指南 - Android 篇(四):配置 depot_tools

1. 引言 在前面的章节中&#xff0c;我们详细介绍了编译 Chromium 132 for Android 所需的系统和硬件要求&#xff0c;以及如何安装和配置基础开发环境和常用工具。完成这些步骤后&#xff0c;接下来需要配置 depot_tools&#xff0c;这是编译 Chromium 的关键工具集。depot_t…

【PromptCoder + Bolt.new】Cascade模式自动生成页面和对应的路由

【PromptCoder Bolt.new】Cascade模式自动生成页面和对应的路由 官网&#xff1a;PromptCoder PromptCoder&#xff1a;智能代码提示词生成 PromptCoder是一款利用人工智能技术的智能代码生成工具。它能够识别设计图或截图&#xff0c;并自动生成与之匹配的前端代码。无论是…