第1章:长上下文技术 (Long Context)#
如何让模型拥有一目十行的"过目不忘"能力?从 RoPE 到 FlashAttention。
目录#
- 一、长上下文的挑战
- 二、位置编码的进化:RoPE (Rotary Positional Embeddings)
- 三、外推技术:打破长度限制
- 四、工程优化:FlashAttention
- 五、显存优化技术
- 六、代码实战:手写一个支持 32k 上下文的 Mini-Llama
- 七、本章小结
一、长上下文的挑战#
在 RAG 和 Agent 应用中,处理长文本(如 100k tokens 甚至 1M tokens)已成为刚需。但 Transformer 在处理长文本时面临三个核心物理瓶颈:
计算复杂度 $O(N^2)$:Attention 的计算矩阵是 $N \times N$。序列长度翻倍,计算量增加 4 倍。
- 4k -> 8k: 计算量增加 4 倍
- 4k -> 100k: 计算量增加 625 倍!
KV Cache 显存爆炸:推理时需要存储所有历史 Token 的 KV 状态。
- LLaMA-2-7B (fp16), 4k context: ~2GB KV Cache
- LLaMA-2-7B (fp16), 100k context: ~50GB KV Cache (单卡 A100 80G 直接撑爆)
位置编码的外推性 (Extrapolation):训练时只见过 4k 长度,测试时给它 100k,位置编码会"乱套"。模型在超出训练长度后,PPL(困惑度)会急剧上升,开始胡言乱语。
二、位置编码的进化:RoPE (Rotary Positional Embeddings)#
1. 绝对位置 vs 相对位置#
在 Transformer 早期,使用的是绝对位置编码(Absolute PE):
- Sinusoidal (Attention is All You Need): $\sin(pos/10000^{2i/d})$
- Learnable (BERT/GPT): 学习一个 Embedding 矩阵 $P \in \mathbb{R}^{seq \times dim}$
问题:绝对位置编码无法捕捉 token 之间的相对距离。对于 “Cat eats fish”,“Cat” 和 “fish” 距离是 2。如果句子变成 “The Cat eats fish”,距离还是 2,但绝对位置变了(从 1,3 变成了 2,4)。模型需要重新学习这种情况。
相对位置编码 (Relative PE):直接在 Attention Score 计算中加入相对距离 $i-j$ 的信息。但实现复杂,且不容易缓存。
2. RoPE 核心原理#
RoPE (Su et al., 2021) 通过将向量在复平面上旋转,巧妙地融合了绝对位置信息,但内积结果却只与相对位置有关。
核心公式:
$$ f(x, m) = x e^{i m \theta} $$
当计算两个位置 $m$ 和 $n$ 的 Query 和 Key 的内积时:
$$ \langle f(q, m), f(k, n) \rangle = \text{Re}(q e^{i m \theta} \cdot \overline{k e^{i n \theta}}) = \text{Re}(q \bar{k} e^{i(m-n)\theta}) $$
神奇之处:结果只包含 $(m-n)$,即相对距离!
这使得 RoPE 具有两个极佳特性:
- 平移不变性:无论 token 出现在句子的哪个位置,只要相对距离一样,Attention 分数就一样。
- 远程衰减:随着相对距离增加,内积值自然衰减(关注近处多于远处)。
3. PyTorch 实现 RoPE#
这是 LLaMA 官方实现的核心代码:
import torch
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""
预计算旋转角度(复数形式)
Args:
dim: head_dim (注意不是 hidden_size)
end: 最大序列长度 max_seq_len
theta: 基频 (LLaMA 1用10000, LLaMA 3用500000)
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end) # 位置索引 [0, 1, ..., end-1]
# 外积计算所有位置的所有频率
freqs = torch.outer(t, freqs).float() # [end, dim//2]
# 转为复数 e^{i*freqs} = cos(freqs) + i*sin(freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
"""
应用 RoPE 旋转
Args:
xq: Query [batch, seq_len, n_heads, head_dim]
xk: Key [batch, seq_len, n_kv_heads, head_dim]
freqs_cis: 预计算的复数频率
"""
# 将 Q, K 重塑为复数张量 (把最后一维 dim 拆成 dim/2 个复数)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# 广播形状以匹配
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
# 复数乘法(即旋转): (a+bi)(c+di) = (ac-bd) + i(ad+bc)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)实战 Tip:
theta的选择至关重要。theta越大,波长越长,能表示的相对距离越远。- LLaMA-1 (2k context): theta = 10000
- LLaMA-2 (4k context): theta = 10000
- CodeLLaMA (100k context): theta = 1000000
- LLaMA-3 (8k context): theta = 500000
三、外推技术:打破长度限制#
如果模型训练时最大长度是 4096 (4k),如何让它在推理时处理 32k 甚至 100k 的文本?
1. 线性内插 (Linear Interpolation)#
问题:直接外推(Extrapolation)效果很差。因为高频位置编码旋转太快,超出训练分布。
思路:把 32k 的长度"压缩"回 4k 的范围内。即欺骗模型。
$$ m’ = m \times \frac{L_{train}}{L_{test}} $$
例如要扩展 8 倍,就让位置 0, 1, 2, …, 32 变成 0, 0.125, 0.25, …, 4。
优点:非常稳定,不用重新训练模型就能跑起来(虽然效果会打折,但比直接崩了强)。 缺点:对于高频特征(关注局部信息的 Attention Head),距离被强行压缩了,导致分辨率下降(“近视眼”)。
代码实现:
只需要在计算 freqs 时除以 scale 因子。
# Linear Scaling
scale = 8.0 # 4k -> 32k
t = torch.arange(end) / scale2. NTK-Aware Scaled RoPE#
这是著名的"Reddit 网友"发现的改进方案。后来被证实与神经正切核 (Neural Tangent Kernel) 理论有关。
核心思想: 低频分量和高频分量应该区别对待。
- 高频分量(捕捉局部关系):保持不变,不进行插值。因为局部关系(“of the”, “in a”)在长文中也不会变。
- 低频分量(捕捉长程关系):进行插值,适应更长距离。
公式实现: 不修改位置索引 $t$,而是修改基频 $base$ (theta)。
$$ \text{Base}’ = \text{Base} \times \alpha^{\frac{dim}{dim-2}} $$
def get_ntk_base(scale: float, dim: int, base: float = 10000.0):
"""
计算 NTK 修正后的 Base
Args:
scale: 扩展倍数 (e.g., 8)
dim: head_dim
"""
# 核心公式:base = base * scale ^ (dim / (dim-2))
new_base = base * (scale ** (dim / (dim - 2)))
return new_base
# 使用新的 base 计算 freqs
freqs = 1.0 / (new_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))效果:不微调的情况下,NTK 插值的 PPL 显著优于线性插值。
3. YaRN (Yet another RoPE for Transformers)#
YaRN 是目前最先进的外推方法之一(DeepSeek-V2, LLaMA-3 都在用类似思想)。
它结合了:
NTK-aware 插值:分频段处理。
Attention Logit 修正: 当上下文变长时,Attention 分布会变得更平滑(Entropy 增加),导致模型注意力涣散。 YaRN 引入一个温度系数 $\sqrt{t}$ 来锐化 Attention:
$$ \text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^T}{\sqrt{d} \cdot t})V $$
四、工程优化:FlashAttention#
算法层面解决了位置编码,计算层面还得靠 FlashAttention。它是大模型训练和推理的基础设施。
1. 显存带宽瓶颈 (Memory Bound)#
在 GPU 中:
- HBM (High Bandwidth Memory): 显存,大但慢 (80GB, 2TB/s)
- SRAM: 类似 L1/L2 Cache,极快但极小 (20MB, 19TB/s)
标准的 Attention 计算: $$S = QK^T \rightarrow P = \text{Softmax}(S) \rightarrow O = PV$$
需要反复将巨大的 $N \times N$ 矩阵在 HBM 和 SRAM 之间搬运:
- 读 Q, K -> 算 $S$ -> 写回 HBM
- 读 $S$ -> 算 Softmax -> 写回 HBM
- 读 $P, V$ -> 算 $O$ -> 写回 HBM
痛点:$N \times N$ 矩阵太大了,根本塞不进 SRAM。而且大部分时间 GPU 核心在等 HBM 搬数据(IO 瓶颈)。
2. FlashAttention V1: Tiling & Recomputation#
FlashAttention (Dao et al., 2022) 的核心魔法是 Tiling (分块)。
算法流程:
将 $Q, K, V$ 切分成小块(Block),比如 $128 \times 128$。
每次只加载一部分块到 SRAM。
在 SRAM 中计算局部的 Attention Score。
Online Softmax:利用数学技巧,不需要一次性看到所有分数就能计算 Softmax 的归一化因子。
$$ m_{new} = \max(m_{old}, \max(x_{new})) $$ $$ l_{new} = l_{old} \cdot e^{m_{old}-m_{new}} + \sum e^{x_{new}-m_{new}} $$
直接在 SRAM 中算完 $O$ 的一部分,只把最终结果写回 HBM。
收益:
- 显存占用:从 $O(N^2)$ 降为 $O(N)$(线性!)。不再需要存储 $N \times N$ 的 Attention Map。
- 速度:加速 2-4 倍(减少了 HBM 访问次数)。
3. FlashAttention V2: 并行优化#
FlashAttention V2 (Dao, 2023) 进一步优化:
- 减少非矩阵运算:把 Softmax 等操作尽量融入矩阵各乘法 (MatMul) 中。
- 更好的并行化:
- V1 主要是按 Batch 和 Head 并行。
- V2 增加了按 Sequence Length 并行(即使 batch size=1 也能占满 GPU)。
实战代码 (使用 PyTorch 2.0+):
现在 PyTorch 2.0 已经内置了 FlashAttention(称为 Scaled Dot Product Attention, SDPA)。
import torch
import torch.nn.functional as F
# 启用 FlashAttention
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=True
)五、显存优化技术#
1. PagedAttention (vLLM)#
随着 Context 变长,KV Cache 成为显存杀手。 传统的 KV Cache 是预分配连续显存的。如果 max_len=2048,即使用户只输入 5 个字,系统也会预留 2048 的槽位(或者产生大量碎片)。
PagedAttention 灵感来自操作系统的 虚拟内存 (Virtual Memory):
- 把 KV Cache 切分成固定大小的 Block (e.g., 16 tokens/block)。
- 逻辑上连续的 token,在显存物理上可以不连续。
- 通过 Block Table 记录映射关系。
优势:
- 零浪费:显存利用率接近 100%。
- 动态分配:生成多少用多少。
- Copy-on-Write:多个请求共享 Prompt 的 KV Cache(如 System Prompt)。
2. KV Cache Quantization#
将 KV Cache 从 FP16 (2 bytes) 压缩到 INT8 (1 byte) 甚至 INT4。
- FP16: 2 * 2 * L * H * D bytes
- INT8: 1 * 2 * L * H * D bytes (省一半显存)
KIVI (2024) 等算法证明,KV Cache 即使量化到 2-bit,对精度影响也很小。
3. Grouped-Query Attention (GQA)#
LLaMA-2 和 LLaMA-3 都使用了 GQA。
- MHA (Multi-Head): Query heads = KV heads (1:1)。KV Cache 最大。
- MQA (Multi-Query): 只有 1 个 KV head,所有 Query 共享。KV Cache 最小,但掉点明显。
- GQA (Grouped-Query): 折中方案。比如 32 个 Query head,8 个 KV head (4:1)。
GQA 在保持高性能的同时,将 KV Cache 显存降低了 4-8 倍。
六、代码实战:手写一个支持 32k 上下文的 Mini-Llama#
为了彻底理解,我们实现一个带有 RoPE 和 FlashAttention 的注意力层。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# 预计算 cos/sin
inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.update_freqs(max_position_embeddings, device)
def update_freqs(self, seq_len, device):
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_position_embeddings:
# 动态扩展(简单线性外推)
self.update_freqs(seq_len, x.device)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# 简单的实现版本,没有使用复数
# q, k: [bs, num_heads, seq_len, head_dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LongContextAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.hidden_size // self.num_heads
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
def forward(self, x, attention_mask=None):
bsz, seq_len, _ = x.shape
# 1. 投影
q = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 2. 应用 RoPE
cos, sin = self.rotary_emb(v, seq_len=seq_len)
q, k = apply_rotary_pos_emb(q, k, cos, sin, None)
# 3. FlashAttention
# 自动选择最优实现(FlashV2 > MemEfficient > Math)
context_layer = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attention_mask,
dropout_p=0.0 if not self.training else 0.1,
is_causal=True
)
# 4. 输出
context_layer = context_layer.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
output = self.o_proj(context_layer)
return output
# 测试代码
device = "cuda" if torch.cuda.is_available() else "cpu"
config = type('Config', (), {'hidden_size': 4096, 'num_heads': 32})()
attn = LongContextAttention(config).to(device)
x = torch.randn(1, 1024, 4096).to(device)
with torch.no_grad():
out = attn(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {out.shape}") # 应该是 [1, 1024, 4096]
print("长上下文 Attention 计算成功!")七、本章小结#
长上下文技术是构建 Agent 记忆系统和大型 RAG 知识库的基础。
- RoPE: 完美的相对位置编码,是 LLaMA 家族的标配。
- NTK/YaRN: “不重新训练模型"就能把 Context 窗口拉长 4-8 倍的魔法。
- FlashAttention: 打破 IO 瓶颈,让 Attention 计算速度跟上 GPU 算力。
- PagedAttention: 像管理内存一样管理显存,解决碎片化问题。
掌握这些技术,你就不再会被 “Context Window exceeded” 报错所困扰。
下一章预告: 第2章 - 新型架构探索
在下一章中,我们将拆解 Mixtral 8x7B 和 DeepSeek-MoE 背后的稀疏激活机制,以及 DeepSeek-V3 的 MLA 架构。