1. 引言
Llama模型的一个重要特性是支持长上下文处理。本文将深入分析Llama源码中实现长上下文的关键技术点,包括位置编码(position embedding)的外推方法、注意力机制的优化等。我们将通过详细的代码解析来理解其实现原理。
2. 位置编码的外推实现
2.1 旋转位置编码(RoPE)基础
Llama采用旋转位置编码(RoPE, Rotary Position Embedding)来编码token的位置信息。RoPE的实现包含几个关键步骤:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, scale: float = 1.0):"""预计算RoPE的频率Args:dim: 隐藏层维度end: 序列最大长度theta: RoPE的基频参数scale: 位置缩放因子Returns:freqs_cis: 复数形式的频率矩阵"""# 生成维度序列 [0, 2, ..., dim-2]dims = torch.arange(0, dim, 2)[: (dim // 2)].float()# 计算频率基数 1/θ^(2i/d)freqs = 1.0 / (theta ** (dims / dim))# 生成位置序列并应用缩放t = torch.arange(end, device=freqs.device) * scale# 计算位置和频率的外积freqs = torch.outer(t, freqs)# 转换为复数形式 e^(iθ)freqs_cis = torch.polar(torch.ones_like(freqs), freqs)return freqs_cisdef apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:"""应用旋转位置编码Args:xq: query张量 [batch_size, seq_len, num_heads, head_dim]xk: key张量 [batch_size, seq_len, num_heads, head_dim]freqs_cis: 预计算的频率 [seq_len, head_dim//2]"""# 重塑张量以方便运算xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)# 提取频率的实部和虚部freqs_cos = freqs_cis.real()freqs_sin = freqs_cis.imag()# 应用旋转变换# xq_out = xq * cos(θ) + rotate_half(xq) * sin(θ)xq_out_r = xq_r * freqs_cos - xq_i * freqs_sinxq_out_i = xq_r * freqs_sin + xq_i * freqs_cosxk_out_r = xk_r * freqs_cos - xk_i * freqs_sinxk_out_i = xk_r * freqs_sin + xk_i * freqs_cos# 重新组合实部和虚部xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)return xq_out.type_as(xq), xk_out.type_as(xk)
2.2 动态NTK外推方案
动态NTK缩放是实现长上下文的关键技术,它通过动态调整位置编码的缩放因子来改善模型在更长序列上的表现:
class LlamaConfig:def __init__(self):self.rope_scaling = {"type": "dynamic", # 动态缩放类型"factor": 2.0, # 基础缩放因子"original_max_position_embeddings": 2048 # 原始训练长度}def compute_dynamic_ntk_scaling(ctx_len: int,orig_ctx_len: int = 2048,base_scale: float = 0.25,alpha: float = 1.0
) -> float:"""计算动态NTK缩放因子Args:ctx_len: 当前上下文长度orig_ctx_len: 原始训练上下文长度base_scale: 基础缩放系数alpha: 缩放曲线的陡峭程度"""# 使用对数曲线计算缩放因子return base_scale * math.log(ctx_len / orig_ctx_len) ** alphaclass LlamaAttention(nn.Module):def __init__(self, config: LlamaConfig):super().__init__()self.config = configself.rope_scaling = config.rope_scalingdef forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,) -> torch.Tensor:"""注意力前向计算Args:hidden_states: 输入张量 [batch_size, seq_len, hidden_size]attention_mask: 注意力掩码position_ids: 位置索引"""seq_len = hidden_states.shape[1]# 计算动态缩放因子if self.rope_scaling["type"] == "dynamic":rope_scale = compute_dynamic_ntk_scaling(seq_len,self.config.rope_scaling["original_max_position_embeddings"],base_scale=self.rope_scaling["factor"])else:rope_scale = 1.0# 计算位置编码freqs_cis = precompute_freqs_cis(self.head_dim,seq_len,scale=rope_scale)# 应用旋转位置编码query_states, key_states = apply_rotary_emb(self.q_proj(hidden_states),self.k_proj(hidden_states),freqs_cis)
3. 注意力机制优化
3.1 分块注意力计算
为了高效处理长序列,Llama实现了分块注意力计算。以下是详细的实现代码:
class ChunkedAttention(nn.Module):def __init__(self, chunk_size: int = 1024):super().__init__()self.chunk_size = chunk_sizedef forward(self,query: torch.Tensor, # [batch, num_heads, seq_len, head_dim]key: torch.Tensor, # [batch, num_heads, seq_len, head_dim]value: torch.Tensor, # [batch, num_heads, seq_len, head_dim]mask: Optional[torch.Tensor] = None) -> torch.Tensor:"""分块计算注意力"""batch_size, num_heads, seq_len, head_dim = query.shape# 计算需要的块数num_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size# 存储每个块的输出chunked_outputs = []# 按块计算注意力for chunk_idx in range(num_chunks):# 计算当前块的起止位置chunk_start = chunk_idx * self.chunk_sizechunk_end = min(chunk_start + self.chunk_size, seq_len)# 提取当前块的querychunk_query = query[:, :, chunk_start:chunk_end]# 计算注意力得分chunk_scores = torch.matmul(chunk_query, # [b, h, chunk_size, d]key.transpose(-2, -1) # [b, h, d, seq_len]) # 得到 [b, h, chunk_size, seq_len]# 缩放注意力得分chunk_scores = chunk_scores / math.sqrt(head_dim)# 应用attention maskif mask is not None:chunk_mask = mask[:, :, chunk_start:chunk_end, :]chunk_scores = chunk_scores + chunk_mask# 应用softmaxchunk_attn = F.softmax(chunk_scores, dim=-1)# 计算输出chunk_output = torch.matmul(chunk_attn, value)chunked_outputs.append(chunk_output)# 拼接所有块的输出return torch.cat(chunked_outputs, dim=2)
3.2 优化的KV Cache实现
KV Cache的实现需要考虑内存效率和计算性能:
class KVCache:def __init__(self,max_batch_size: int,max_seq_length: int,num_heads: int,head_dim: int,dtype: torch.dtype = torch.float16):"""初始化KV缓存Args:max_batch_size: 最大批次大小max_seq_length: 最大序列长度num_heads: 注意力头数head_dim: 每个头的维度dtype: 数据类型"""self.max_seq_length = max_seq_length# 初始化缓存张量self.k_cache = torch.zeros(max_batch_size,num_heads,max_seq_length,head_dim,dtype=dtype)self.v_cache = torch.zeros(max_batch_size,num_heads,max_seq_length,head_dim,dtype=dtype)# 记录当前序列长度self.current_length = 0def update(self,key: torch.Tensor,value: torch.Tensor,position: int) -> None:"""更新缓存Args:key: key状态 [batch_size, num_heads, seq_len, head_dim]value: value状态 [batch_size, num_heads, seq_len, head_dim]position: 起始位置"""seq_len = key.shape[2]if position + seq_len > self.max_seq_length:raise ValueError(f"Position {position + seq_len} exceeds max_seq_length {self.max_seq_length}")# 更新缓存self.k_cache[:, :, position:position+seq_len] = keyself.v_cache[:, :, position:position+seq_len] = value# 更新当前长度self.current_length = max(self.current_length, position + seq_len)def get_cached_kv(self,start_pos: int,end_pos: int) -> Tuple[torch.Tensor, torch.Tensor]:"""获取指定范围的缓存内容"""return (self.k_cache[:, :, start_pos:end_pos],self.v_cache[:, :, start_pos:end_pos])def clear(self) -> None:"""清空缓存"""self.k_cache.zero_()self.v_cache.zero_()self.current_length = 0
4. 实际应用示例
让我们看一个完整的使用示例,展示如何处理长文本:
class LongContextProcessor:def __init__(self,model: LlamaModel,tokenizer,max_length: int = 16384,chunk_size: int = 1024):self.model = modelself.tokenizer = tokenizerself.chunk_size = chunk_size# 初始化KV缓存self.kv_cache = KVCache(max_batch_size=1,max_seq_length=max_length,num_heads=model.config.num_attention_heads,head_dim=model.config.hidden_size // model.config.num_attention_heads)def process_long_text(self, text: str) -> torch.Tensor:"""处理长文本输入Args:text: 输入文本Returns:处理后的隐藏状态"""# 分词tokens = self.tokenizer(text,return_tensors="pt",truncation=False).input_ids# 清空KV缓存self.kv_cache.clear()# 分块处理all_hidden_states = []for i in range(0, tokens.size(1), self.chunk_size):# 获取当前块chunk = tokens[:, i:i+self.chunk_size]# 获取位置编码索引position_ids = torch.arange(i,i + chunk.size(1),dtype=torch.long,device=chunk.device).unsqueeze(0)# 获取当前位置的缓存k_cache, v_cache = self.kv_cache.get_cached_kv(0, i)# 前向计算outputs = self.model(chunk,position_ids=position_ids,past_key_values=[(k_cache, v_cache)] * self.model.config.num_hidden_layers)# 更新缓存self.kv