第1章:长上下文技术 (Long Context)#

如何让模型拥有一目十行的"过目不忘"能力?从 RoPE 到 FlashAttention。


目录#


一、长上下文的挑战#

在 RAG 和 Agent 应用中,处理长文本(如 100k tokens 甚至 1M tokens)已成为刚需。但 Transformer 在处理长文本时面临三个核心物理瓶颈:

  1. 计算复杂度 $O(N^2)$:Attention 的计算矩阵是 $N \times N$。序列长度翻倍,计算量增加 4 倍。

    • 4k -> 8k: 计算量增加 4 倍
    • 4k -> 100k: 计算量增加 625 倍!
  2. KV Cache 显存爆炸:推理时需要存储所有历史 Token 的 KV 状态。

    • LLaMA-2-7B (fp16), 4k context: ~2GB KV Cache
    • LLaMA-2-7B (fp16), 100k context: ~50GB KV Cache (单卡 A100 80G 直接撑爆)
  3. 位置编码的外推性 (Extrapolation):训练时只见过 4k 长度,测试时给它 100k,位置编码会"乱套"。模型在超出训练长度后,PPL(困惑度)会急剧上升,开始胡言乱语。


二、位置编码的进化:RoPE (Rotary Positional Embeddings)#

1. 绝对位置 vs 相对位置#

在 Transformer 早期,使用的是绝对位置编码(Absolute PE):

  • Sinusoidal (Attention is All You Need): $\sin(pos/10000^{2i/d})$
  • Learnable (BERT/GPT): 学习一个 Embedding 矩阵 $P \in \mathbb{R}^{seq \times dim}$

问题:绝对位置编码无法捕捉 token 之间的相对距离。对于 “Cat eats fish”,“Cat” 和 “fish” 距离是 2。如果句子变成 “The Cat eats fish”,距离还是 2,但绝对位置变了(从 1,3 变成了 2,4)。模型需要重新学习这种情况。

相对位置编码 (Relative PE):直接在 Attention Score 计算中加入相对距离 $i-j$ 的信息。但实现复杂,且不容易缓存。

2. RoPE 核心原理#

RoPE (Su et al., 2021) 通过将向量在复平面上旋转,巧妙地融合了绝对位置信息,但内积结果却只与相对位置有关。

核心公式:

$$ f(x, m) = x e^{i m \theta} $$

当计算两个位置 $m$ 和 $n$ 的 Query 和 Key 的内积时:

$$ \langle f(q, m), f(k, n) \rangle = \text{Re}(q e^{i m \theta} \cdot \overline{k e^{i n \theta}}) = \text{Re}(q \bar{k} e^{i(m-n)\theta}) $$

神奇之处:结果只包含 $(m-n)$,即相对距离

这使得 RoPE 具有两个极佳特性:

  1. 平移不变性:无论 token 出现在句子的哪个位置,只要相对距离一样,Attention 分数就一样。
  2. 远程衰减:随着相对距离增加,内积值自然衰减(关注近处多于远处)。

3. PyTorch 实现 RoPE#

这是 LLaMA 官方实现的核心代码:

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    预计算旋转角度(复数形式)

    Args:
        dim: head_dim (注意不是 hidden_size)
        end: 最大序列长度 max_seq_len
        theta: 基频 (LLaMA 1用10000, LLaMA 3用500000)
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end)  # 位置索引 [0, 1, ..., end-1]

    # 外积计算所有位置的所有频率
    freqs = torch.outer(t, freqs).float()  # [end, dim//2]

    # 转为复数 e^{i*freqs} = cos(freqs) + i*sin(freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """
    应用 RoPE 旋转

    Args:
        xq: Query [batch, seq_len, n_heads, head_dim]
        xk: Key   [batch, seq_len, n_kv_heads, head_dim]
        freqs_cis: 预计算的复数频率
    """
    # 将 Q, K 重塑为复数张量 (把最后一维 dim 拆成 dim/2 个复数)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # 广播形状以匹配
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)

    # 复数乘法(即旋转): (a+bi)(c+di) = (ac-bd) + i(ad+bc)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

实战 Tip

  • theta 的选择至关重要。theta 越大,波长越长,能表示的相对距离越远。
  • LLaMA-1 (2k context): theta = 10000
  • LLaMA-2 (4k context): theta = 10000
  • CodeLLaMA (100k context): theta = 1000000
  • LLaMA-3 (8k context): theta = 500000

三、外推技术:打破长度限制#

如果模型训练时最大长度是 4096 (4k),如何让它在推理时处理 32k 甚至 100k 的文本?

1. 线性内插 (Linear Interpolation)#

问题:直接外推(Extrapolation)效果很差。因为高频位置编码旋转太快,超出训练分布。

思路:把 32k 的长度"压缩"回 4k 的范围内。即欺骗模型

$$ m’ = m \times \frac{L_{train}}{L_{test}} $$

例如要扩展 8 倍,就让位置 0, 1, 2, …, 32 变成 0, 0.125, 0.25, …, 4。

优点:非常稳定,不用重新训练模型就能跑起来(虽然效果会打折,但比直接崩了强)。 缺点:对于高频特征(关注局部信息的 Attention Head),距离被强行压缩了,导致分辨率下降(“近视眼”)。

代码实现: 只需要在计算 freqs 时除以 scale 因子。

# Linear Scaling
scale = 8.0  # 4k -> 32k
t = torch.arange(end) / scale

2. NTK-Aware Scaled RoPE#

这是著名的"Reddit 网友"发现的改进方案。后来被证实与神经正切核 (Neural Tangent Kernel) 理论有关。

核心思想: 低频分量和高频分量应该区别对待。

  • 高频分量(捕捉局部关系):保持不变,不进行插值。因为局部关系(“of the”, “in a”)在长文中也不会变。
  • 低频分量(捕捉长程关系):进行插值,适应更长距离。

公式实现: 不修改位置索引 $t$,而是修改基频 $base$ (theta)。

$$ \text{Base}’ = \text{Base} \times \alpha^{\frac{dim}{dim-2}} $$

def get_ntk_base(scale: float, dim: int, base: float = 10000.0):
    """
    计算 NTK 修正后的 Base

    Args:
        scale: 扩展倍数 (e.g., 8)
        dim: head_dim
    """
    # 核心公式:base = base * scale ^ (dim / (dim-2))
    new_base = base * (scale ** (dim / (dim - 2)))
    return new_base

# 使用新的 base 计算 freqs
freqs = 1.0 / (new_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

效果:不微调的情况下,NTK 插值的 PPL 显著优于线性插值。

3. YaRN (Yet another RoPE for Transformers)#

YaRN 是目前最先进的外推方法之一(DeepSeek-V2, LLaMA-3 都在用类似思想)。

它结合了:

  1. NTK-aware 插值:分频段处理。

  2. Attention Logit 修正: 当上下文变长时,Attention 分布会变得更平滑(Entropy 增加),导致模型注意力涣散。 YaRN 引入一个温度系数 $\sqrt{t}$ 来锐化 Attention:

    $$ \text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^T}{\sqrt{d} \cdot t})V $$


四、工程优化:FlashAttention#

算法层面解决了位置编码,计算层面还得靠 FlashAttention。它是大模型训练和推理的基础设施

1. 显存带宽瓶颈 (Memory Bound)#

在 GPU 中:

  • HBM (High Bandwidth Memory): 显存,大但慢 (80GB, 2TB/s)
  • SRAM: 类似 L1/L2 Cache,极快但极小 (20MB, 19TB/s)

标准的 Attention 计算: $$S = QK^T \rightarrow P = \text{Softmax}(S) \rightarrow O = PV$$

需要反复将巨大的 $N \times N$ 矩阵在 HBM 和 SRAM 之间搬运:

  1. 读 Q, K -> 算 $S$ -> 写回 HBM
  2. 读 $S$ -> 算 Softmax -> 写回 HBM
  3. 读 $P, V$ -> 算 $O$ -> 写回 HBM

痛点:$N \times N$ 矩阵太大了,根本塞不进 SRAM。而且大部分时间 GPU 核心在等 HBM 搬数据(IO 瓶颈)。

2. FlashAttention V1: Tiling & Recomputation#

FlashAttention (Dao et al., 2022) 的核心魔法是 Tiling (分块)

算法流程

  1. 将 $Q, K, V$ 切分成小块(Block),比如 $128 \times 128$。

  2. 每次只加载一部分块到 SRAM。

  3. 在 SRAM 中计算局部的 Attention Score。

  4. Online Softmax:利用数学技巧,不需要一次性看到所有分数就能计算 Softmax 的归一化因子。

    $$ m_{new} = \max(m_{old}, \max(x_{new})) $$ $$ l_{new} = l_{old} \cdot e^{m_{old}-m_{new}} + \sum e^{x_{new}-m_{new}} $$

  5. 直接在 SRAM 中算完 $O$ 的一部分,只把最终结果写回 HBM。

收益

  • 显存占用:从 $O(N^2)$ 降为 $O(N)$(线性!)。不再需要存储 $N \times N$ 的 Attention Map。
  • 速度:加速 2-4 倍(减少了 HBM 访问次数)。

3. FlashAttention V2: 并行优化#

FlashAttention V2 (Dao, 2023) 进一步优化:

  1. 减少非矩阵运算:把 Softmax 等操作尽量融入矩阵各乘法 (MatMul) 中。
  2. 更好的并行化
    • V1 主要是按 Batch 和 Head 并行。
    • V2 增加了按 Sequence Length 并行(即使 batch size=1 也能占满 GPU)。

实战代码 (使用 PyTorch 2.0+)

现在 PyTorch 2.0 已经内置了 FlashAttention(称为 Scaled Dot Product Attention, SDPA)。

import torch
import torch.nn.functional as F

# 启用 FlashAttention
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    output = F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True
    )

五、显存优化技术#

1. PagedAttention (vLLM)#

随着 Context 变长,KV Cache 成为显存杀手。 传统的 KV Cache 是预分配连续显存的。如果 max_len=2048,即使用户只输入 5 个字,系统也会预留 2048 的槽位(或者产生大量碎片)。

PagedAttention 灵感来自操作系统的 虚拟内存 (Virtual Memory)

  • 把 KV Cache 切分成固定大小的 Block (e.g., 16 tokens/block)。
  • 逻辑上连续的 token,在显存物理上可以不连续。
  • 通过 Block Table 记录映射关系。

优势

  • 零浪费:显存利用率接近 100%。
  • 动态分配:生成多少用多少。
  • Copy-on-Write:多个请求共享 Prompt 的 KV Cache(如 System Prompt)。

2. KV Cache Quantization#

将 KV Cache 从 FP16 (2 bytes) 压缩到 INT8 (1 byte) 甚至 INT4。

  • FP16: 2 * 2 * L * H * D bytes
  • INT8: 1 * 2 * L * H * D bytes (省一半显存)

KIVI (2024) 等算法证明,KV Cache 即使量化到 2-bit,对精度影响也很小。

3. Grouped-Query Attention (GQA)#

LLaMA-2 和 LLaMA-3 都使用了 GQA。

  • MHA (Multi-Head): Query heads = KV heads (1:1)。KV Cache 最大。
  • MQA (Multi-Query): 只有 1 个 KV head,所有 Query 共享。KV Cache 最小,但掉点明显。
  • GQA (Grouped-Query): 折中方案。比如 32 个 Query head,8 个 KV head (4:1)。

GQA 在保持高性能的同时,将 KV Cache 显存降低了 4-8 倍。


六、代码实战:手写一个支持 32k 上下文的 Mini-Llama#

为了彻底理解,我们实现一个带有 RoPEFlashAttention 的注意力层。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=32768, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # 预计算 cos/sin
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        self.update_freqs(max_position_embeddings, device)

    def update_freqs(self, seq_len, device):
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_position_embeddings:
            # 动态扩展(简单线性外推)
            self.update_freqs(seq_len, x.device)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # 简单的实现版本,没有使用复数
    # q, k: [bs, num_heads, seq_len, head_dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class LongContextAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_heads

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)

    def forward(self, x, attention_mask=None):
        bsz, seq_len, _ = x.shape

        # 1. 投影
        q = self.q_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 2. 应用 RoPE
        cos, sin = self.rotary_emb(v, seq_len=seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, None)

        # 3. FlashAttention
        # 自动选择最优实现(FlashV2 > MemEfficient > Math)
        context_layer = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=0.0 if not self.training else 0.1,
            is_causal=True
        )

        # 4. 输出
        context_layer = context_layer.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
        output = self.o_proj(context_layer)

        return output

# 测试代码
device = "cuda" if torch.cuda.is_available() else "cpu"
config = type('Config', (), {'hidden_size': 4096, 'num_heads': 32})()
attn = LongContextAttention(config).to(device)

x = torch.randn(1, 1024, 4096).to(device)
with torch.no_grad():
    out = attn(x)
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {out.shape}")  # 应该是 [1, 1024, 4096]
    print("长上下文 Attention 计算成功!")

七、本章小结#

长上下文技术是构建 Agent 记忆系统和大型 RAG 知识库的基础。

  1. RoPE: 完美的相对位置编码,是 LLaMA 家族的标配。
  2. NTK/YaRN: “不重新训练模型"就能把 Context 窗口拉长 4-8 倍的魔法。
  3. FlashAttention: 打破 IO 瓶颈,让 Attention 计算速度跟上 GPU 算力。
  4. PagedAttention: 像管理内存一样管理显存,解决碎片化问题。

掌握这些技术,你就不再会被 “Context Window exceeded” 报错所困扰。


下一章预告: 第2章 - 新型架构探索

在下一章中,我们将拆解 Mixtral 8x7B 和 DeepSeek-MoE 背后的稀疏激活机制,以及 DeepSeek-V3 的 MLA 架构。

[统计组件仅在生产环境显示]