第3章:推理加速黑科技 (Inference Acceleration)#

本章定位:在不改变模型权重的前提下,让推理速度提升 2-3 倍。核心技术:投机解码(Speculative Decoding)、Medusa 多头预测Lookahead 前瞻解码。这些技术已被集成到 vLLM/TGI/SGLang 等生产系统中。


目录#


1. 自回归解码的性能瓶颈#

1.1 为什么 Transformer 推理这么慢?#

LLM 的生成是自回归 (Autoregressive) 的:每次只能生成一个 Token,必须等上一个 Token 出来才能生成下一个。

生成 "Hello, how are you today?"
Step 1: [] → "Hello"
Step 2: ["Hello"] → ","
Step 3: ["Hello", ","] → "how"
...
Step N: ["Hello", ",", "how", ..., "you"] → "today"

关键问题

  • 每一步都要做一次完整的前向传播 (Forward Pass)。
  • 如果生成 100 个 Token,就要跑 100 次模型。
  • 即使使用了 KV Cache,仍然是顺序依赖

1.2 Batch Size=1 的GPU利用率灾难#

Prefill 阶段 (处理输入 Prompt):

  • 输入长度 N,一次性计算所有 Token 的 Attention。
  • 矩阵乘法维度: [N, d_model] × [d_model, d_model]
  • GPU 利用率: (80%+)。

Decode 阶段 (生成输出):

  • 每次只处理 1 个 Token。
  • 矩阵乘法维度: [1, d_model] × [d_model, d_model]
  • GPU 利用率:极低 (5-15%)。

这意味着在生成阶段,GPU 的大量计算核心都在空转

结论:如果能在一次前向传播中生成多个 Token,就能大幅提升效率。


2. 投机解码 (Speculative Decoding)#

论文: Fast Inference from Transformers via Speculative Decoding (DeepMind, 2023)

2.1 核心思想:草稿模型 + 并行验证#

投机解码使用两个模型:

  1. Draft Model (草稿模型): 小而快 (如 Qwen2-0.5B)。
  2. Target Model (目标模型): 大而准 (如 Qwen2-7B)。

流程:

1. 草稿模型快速生成 K 个候选 Token (K=4-5)
   Draft: ["Hello", "there", "!", "How"]

2. 目标模型一次性验证这 K 个 Token
   - 并行计算 logits(Token_1), logits(Token_2), ...
   - 接受或拒绝每个 Token

3. 如果全部接受 → 跳过 K 步
   如果部分拒绝 → 从拒绝点重新生成

关键优势:

  • 目标模型只需跑 1 次前向传播,而不是 K 次。
  • 即使草稿模型有时猜错,也不影响最终输出质量。

2.2 数学原理:无损加速的保证#

投机解码的核心是概率修正 (Rejection Sampling),确保输出分布与目标模型完全一致。

设:

  • $p(x)$: 目标模型的概率分布
  • $q(x)$: 草稿模型的概率分布
  • $\gamma = \frac{p(x)}{q(x)}$: 重要性权重

接受准则: 对于草稿模型生成的 Token $x_i$,以概率 $\min(1, \gamma)$ 接受。

$$ \text{Accept}(x_i) = \begin{cases} \text{True}, & \text{if } \mathcal{U}(0,1) \leq \min\left(1, \frac{p(x_i \mid x_{<i})}{q(x_i \mid x_{<i})}\right) \ \text{False}, & \text{otherwise} \end{cases} $$

为什么是无损的?

这是经典的拒绝采样 (Rejection Sampling)。可以严格证明:最终采样出的 Token 序列分布 = $p(x)$。

直觉解释:

  • 如果草稿模型的预测 $q(x_i)$ 与目标模型 $p(x_i)$ 接近 → $\gamma \approx 1$ → 高概率接受。
  • 如果草稿模型严重偏离 → $\gamma \ll 1$ → 拒绝,并从目标模型重新采样。

2.3 实战:用 Qwen2-0.5B 加速 Qwen2-7B#

以下是基于 Transformers 的简化实现 (生产环境推荐用 vLLM 的原生支持):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class SpeculativeDecoder:
    def __init__(self, draft_model_id, target_model_id, device="cuda"):
        self.draft = AutoModelForCausalLM.from_pretrained(draft_model_id).to(device)
        self.target = AutoModelForCausalLM.from_pretrained(target_model_id).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(target_model_id)
        self.device = device

    @torch.no_grad()
    def generate(self, prompt, max_new_tokens=100, k=5):
        """
        投机解码生成
        k: 每次草稿模型生成的 Token 数量 (lookahead 长度)
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated = input_ids

        while generated.shape[1] < max_new_tokens:
            # Step 1: 草稿模型快速生成 k 个 Token
            draft_tokens = []
            draft_input = generated

            for _ in range(k):
                draft_logits = self.draft(draft_input).logits[:, -1, :]
                next_token = torch.argmax(draft_logits, dim=-1, keepdim=True)
                draft_tokens.append(next_token)
                draft_input = torch.cat([draft_input, next_token], dim=1)

            # Step 2: 目标模型并行验证
            # 关键: 一次前向传播计算所有 k 个位置的 logits
            target_logits = self.target(draft_input).logits

            # Step 3: 逐 Token 验证
            accepted = 0
            for i in range(k):
                # 取出第 i 个 Token 位置的目标模型概率
                target_probs = torch.softmax(
                    target_logits[:, -(k-i+1), :], dim=-1
                )
                draft_token_id = draft_tokens[i].item()
                p_target = target_probs[0, draft_token_id].item()

                # 草稿模型的概率
                draft_logits_i = self.draft(generated).logits[:, -1, :]
                draft_probs = torch.softmax(draft_logits_i, dim=-1)
                p_draft = draft_probs[0, draft_token_id].item()

                # 接受准则: min(1, p_target / p_draft)
                gamma = p_target / (p_draft + 1e-10)
                if torch.rand(1).item() < min(1.0, gamma):
                    # 接受
                    generated = torch.cat([generated, draft_tokens[i]], dim=1)
                    accepted += 1
                else:
                    # 拒绝,从目标模型重新采样
                    new_token = torch.multinomial(target_probs, 1)
                    generated = torch.cat([generated, new_token], dim=1)
                    break  # 停止后续验证

            if accepted == 0:
                # 如果一个都没接受,至少生成一个 Token
                target_logits_final = self.target(generated).logits[:, -1, :]
                new_token = torch.argmax(target_logits_final, dim=-1, keepdim=True)
                generated = torch.cat([generated, new_token], dim=1)

        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

# 使用示例
decoder = SpeculativeDecoder(
    draft_model_id="Qwen/Qwen2-0.5B-Instruct",
    target_model_id="Qwen/Qwen2-7B-Instruct"
)

output = decoder.generate("Explain quantum computing in simple terms:")
print(output)

实际加速效果:

  • 理想情况 (Draft 模型准确率高): 2.5-3x 加速。
  • 一般情况: 1.8-2x 加速。
  • 最坏情况 (Draft 模型完全随机): 无加速甚至变慢 (因为验证开销)。

生产环境集成 (vLLM):

python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen2-7B-Instruct \
    --speculative-model Qwen/Qwen2-0.5B-Instruct \
    --num-speculative-tokens 5

3. Medusa:多头并行预测#

论文: Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads (2024)

Speculative Decoding 需要额外的草稿模型,而 Medusa 只需要在原模型基础上增加几个轻量级的预测头。

3.1 架构:在 LM Head 之上增加多个预测头#

                      ┌─→ Head_1 (预测下1个Token)
Hidden States ────────┼─→ Head_2 (预测下2个Token)
(from last layer)     │
                      ├─→ Head_3 (预测下3个Token)
                      └─→ ... (最多预测 K=5 个)

关键特点:

  • 每个 Medusa Head 都是一个轻量级的 MLP (如 1-2 层,维度 2048)。
  • Head_k 负责预测"未来第 k 个 Token"。
  • 所有 Head 并行输出,一次前向传播得到多个候选。

3.2 训练:自监督蒸馏#

Medusa Head 的训练不需要额外标注,直接从原模型的生成结果中学习。

训练数据构造: 给定序列 [x1, x2, x3, x4, x5],对于位置 x3:

  • Head_1 的标签 = x4
  • Head_2 的标签 = x5
  • Head_3 的标签 = x6 (如果存在)

损失函数: $$ \mathcal{L} = \sum_{k=1}^{K} \text{CE}(\text{Head}k(h_t), x{t+k}) $$

其中 $h_t$ 是位置 $t$ 的隐藏状态,$x_{t+k}$ 是未来第 $k$ 个 Token。

训练流程:

from transformers import AutoModelForCausalLM
import torch.nn as nn

class MedusaHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, num_heads=3):
        super().__init__()
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, vocab_size)
            ) for _ in range(num_heads)
        ])

    def forward(self, hidden_states):
        """
        hidden_states: [batch, seq, hidden]
        返回: list of [batch, seq, vocab_size]
        """
        return [head(hidden_states) for head in self.heads]

# 训练伪代码
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B")
medusa = MedusaHead(hidden_size=3584, vocab_size=151936, num_heads=4)

# 冻结 base_model,只训练 Medusa Head
for param in base_model.parameters():
    param.requires_grad = False

# ... 标准的监督学习训练循环 ...

训练成本:

  • 数据量: 约 10-50M Tokens (远小于预训练)。
  • 时间: 单卡 A100 训练 1-2 天。
  • 显存: 与原模型推理相当 (因为 base_model 冻结)。

3.3 Tree Attention 优化#

Medusa 的核心挑战:如何高效地验证多个候选路径?

假设每个 Head 输出 Top-2 候选,K=3 个 Head,则可能的路径数 = $2^3 = 8$ 条。

朴素方案: 逐条验证 → 8 次前向传播 → 没有加速

Tree Attention: 将所有候选路径组织成一棵树,用一次前向传播并行验证。

        Root (当前Token)
        /            \
    Cand_1_A      Cand_1_B  (Head 1 的 Top-2)
    /    \         /    \
Cand_2_A Cand_2_B ...   ... (Head 2 的候选)

实现: 使用特殊的 Attention Mask,让每个节点只能看到其祖先:

# 构造 Tree Attention Mask
# 示例: 8 个候选路径
#   Mask[i, j] = 1 表示 Token i 可以 attend to Token j
mask = torch.tensor([
    [1, 0, 0, 0, 0, 0, 0, 0],  # Root
    [1, 1, 0, 0, 0, 0, 0, 0],  # Path 1-A
    [1, 0, 1, 0, 0, 0, 0, 0],  # Path 1-B
    [1, 1, 0, 1, 0, 0, 0, 0],  # Path 1-A -> 2-A
    # ...
])

加速效果:

  • 开源榜单 (MT-Bench): 2.2-2.8x 加速。
  • 代码生成 (HumanEval): 1.5-2x 加速 (因为代码的可预测性更强)。

4. Lookahead Decoding:前瞻解码#

论文: Break the Sequential Dependency of LLM Inference Using Lookahead Decoding (ICLR 2024)

4.1 N-gram 缓存原理#

Lookahead Decoding 利用一个观察:LLM 的输出存在大量重复模式

例如在代码生成中:

for i in range(10):
    print(i)
# 之后很可能再次出现
for j in range(10):
    print(j)

如果我们缓存了 "for i in range(10):\n print(i)" 的生成结果,那么生成 for j 时可以直接复用部分计算。

核心数据结构: N-gram Cache

ngram_cache = {
    ("for", "i", "in"): "range",
    ("range", "(", "10"): ")",
    # ...
}

4.2 Jacobi 迭代并行化#

Lookahead 的第二个技巧:Jacobi Decoding

传统解码:

x_1 = f(x_0)
x_2 = f(x_1)
x_3 = f(x_2)  # 顺序依赖

Jacobi 迭代:

# 并行猜测
x'_1, x'_2, x'_3 = 随机初始化

# 迭代优化
for iter in range(T):
    x'_1 = f(x_0)         # 基于真实前缀
    x'_2 = f(x'_1)        # 基于猜测
    x'_3 = f(x'_2)
    # 如果收敛 → 提前退出

关键洞察: 如果猜测合理,Jacobi 迭代通常 2-3 步即可收敛,而不是线性的 N 步。

4.3 适用场景分析#

Lookahead Decoding 的加速效果高度依赖任务特性:

任务类型加速效果原因
代码生成2-3x高度结构化,重复模式多
数学推导1.8-2.5x公式有固定格式
闲聊对话1.2-1.5x随机性强,难以预测
翻译1.5-2x句式有规律

生产环境启用 (SGLang):

python -m sglang.launch_server \
    --model Qwen/Qwen2-7B-Instruct \
    --enable-lookahead \
    --lookahead-window 4

5. 其他前沿技术#

5.1 Eagle:基于特征的推测#

论文: EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty (2024)

创新点: 传统 Speculative Decoding 在 Token 空间做推测。EAGLE 在隐藏状态 (Feature) 空间做推测。

优势:

  • 隐藏状态的维度 (如 4096) 远小于词表 (如 100k)。
  • 特征空间更平滑,更容易预测。

架构:

Draft Model: 轻量级 Transformer
输入: 目标模型的最后一层隐藏状态 h_t
输出: 预测未来的隐藏状态 h_{t+1}, h_{t+2}, ...

加速效果: 比传统 Speculative Decoding 再提升 20-30%

5.2 Cascade Speculation:层级推测#

核心思想: 不同的 Token 难度不同。

  • 简单 Token (如标点、连词): 用极小模型 (0.1B) 生成。
  • 中等 Token (如常见名词): 用小模型 (0.5B) 生成。
  • 困难 Token (如专业术语): 用大模型 (7B) 生成。

实现: 构建一个"模型金字塔",从小到大逐级验证。


本章小结#

方法核心思想加速比是否需要额外训练适用场景
Speculative Decoding小模型草稿 + 大模型验证2-3x通用
Medusa多头并行预测2-2.8x是 (轻量)通用
LookaheadN-gram 缓存 + Jacobi 迭代1.5-3x结构化任务
Eagle特征空间推测2.5-3.5x是 (中等)通用

工程建议:

  1. 快速试验: 先用 Speculative Decoding (vLLM 原生支持,无需训练)。
  2. 极致性能: 训练 Medusa Head (1-2 天成本,长期收益高)。
  3. 特定任务: 代码/数学任务开启 Lookahead。

终极组合: Speculative Decoding + Medusa

  • 草稿模型用 Medusa-enhanced 0.5B。
  • 目标模型用标准 7B。
  • 理论加速: 3-5x

下一章预告: 第4章 - 推理模型专题 (DeepSeek-R1 / OpenAI o1)

在下一章,我们将探讨慢推理(Slow Inference)——如何通过增加推理时计算(Test-Time Compute)来换取更高的准确性,这与本章的加速技术形成互补。

[统计组件仅在生产环境显示]