🎲

Speculative Decoding

投机解码 / 猜测解码

📖 什么是投机解码?

Speculative Decoding(投机解码)是一种LLM推理加速技术:用一个小模型快速生成草稿,然后用大模型并行验证,接受正确的部分,拒绝错误的部分。

核心优势:不损失任何质量的前提下,实现2-3倍加速。被Google、Meta、DeepSeek等广泛采用。

💡 一句话理解

📝 比喻

「实习生写草稿,主管审核」

实习生(小模型)快速写好初稿,主管(大模型)审核。如果草稿正确,直接通过;如果有错,主管亲自改。因为实习生写得快,主管只看不改,效率就上去了。

⚙️ 工作流程

🚀 小模型快速生成:["今", "天", "天气", "很", "好"]
🔍 大模型并行验证:验证这5个token的概率
✅ 接受:["今", "天", "天气", "很"] (4个正确)
❌ 拒绝:"好" → 大模型生成正确的:"不错"
🎉 最终输出:"今天天气很不错"

关键点:大模型一次前向传播就能验证多个token,而不是逐个生成。

📊 为什么能加速?

# 传统解码
for i in range(n_tokens):
    大模型前向传播()  # n次大模型调用

# 投机解码
草稿 = 小模型快速生成(k个token)  # 小模型快10倍
接受数 = 大模型一次验证(草稿)    # 1次大模型调用

# 如果接受率 = 70%,k = 5
# 平均每次大模型调用产出:0.7 × 5 = 3.5个token
# 加速比 ≈ 3.5x(理论上限)

💻 代码示例

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载大模型(目标模型)和小模型(草稿模型)
target_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-72B")
draft_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-72B")

def speculative_decode(prompt, max_tokens=100, k=5):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    
    generated = input_ids.clone()
    
    for _ in range(max_tokens // k):
        # 1. 小模型快速生成k个token
        draft_output = draft_model.generate(
            generated, 
            max_new_tokens=k,
            do_sample=True
        )
        draft_tokens = draft_output[0, generated.shape[1]:]
        
        # 2. 大模型验证(并行)
        with torch.no_grad():
            target_logits = target_model(draft_output).logits
        
        # 3. 逐个验证并决定接受/拒绝
        accepted = 0
        for i, token in enumerate(draft_tokens):
            prob = torch.softmax(target_logits[0, generated.shape[1] + i - 1], dim=-1)
            if torch.rand() < prob[token]:  # 接受
                accepted += 1
            else:  # 拒绝,从大模型分布采样
                new_token = torch.multinomial(prob, 1)
                generated = torch.cat([generated, new_token.unsqueeze(0)], dim=1)
                break
        
        if accepted > 0:
            generated = torch.cat([generated, draft_tokens[:accepted].unsqueeze(0)], dim=1)
    
    return tokenizer.decode(generated[0])

🎯 最佳实践

📊 草稿模型选择

小模型×4-10倍参数

🔢 猜测长度k

通常4-8,平衡效率

📈 接受率优化

同系列模型接受率更高

适用场景

批量推理、长文本生成

🏛️ 实际应用

  • Google: Palm、Gemini使用投机解码
  • DeepSeek: DeepSeek-V3支持投机解码
  • Medusa: 多头投机解码架构
  • vLLM: 生产级投机解码支持

📖 相关导航

← 返回术语百科 | 首页 | 文章 | 专题