Speculative Decoding: Accelerating LLM Inference
When to Use This Skill
Use Speculative Decoding when you need to:
- Speed up inference by 1.5-3.6× without quality loss
- Reduce latency for real-time applications (chatbots, code generation)
- Optimize throughput for high-volume serving
- Deploy efficiently on limited hardware
- Generate faster without changing model architecture
Key Techniques: Draft model speculative decoding, Medusa (multiple heads), Lookahead Decoding (Jacobi iteration)
Papers: Medusa (arXiv 2401.10774), Lookahead Decoding (ICML 2024), Speculative Decoding Survey (ACL 2024)
Installation
# Standard speculative decoding (transformers)
pip install transformers accelerate
# Medusa (multiple decoding heads)
git clone https://github.com/FasterDecoding/Medusa
cd Medusa
pip install -e .
# Lookahead Decoding
git clone https://github.com/hao-ai-lab/LookaheadDecoding
cd LookaheadDecoding
pip install -e .
# Optional: vLLM with speculative decoding
pip install vllm
Quick Start
Basic Speculative Decoding (Draft Model)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load target model (large, slow)
target_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-hf",
device_map="auto",
torch_dtype=torch.float16
)
# Load draft model (small, fast)
draft_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
device_map="auto",
torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")
# Generate with speculative decoding
prompt = "Explain quantum computing in simple terms:"
inputs = tokenizer(prompt, return_tensors="pt").to(target_model.device)
# Transformers 4.36+ supports assisted generation
outputs = target_model.generate(
**inputs,
assistant_model=draft_model, # Enable speculative decoding
max_new_tokens=256,
do_sample=True,
temperature=0.7,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
Medusa (Multiple Decoding Heads)
import torch
from transformers import AutoTokenizer
from medusa.model.medusa_model import MedusaModel
# Load Medusa-enhanced model
model = MedusaModel.from_pretrained(
"FasterDecoding/medusa-vicuna-7b-v1.3", # Pre-trained with Medusa heads
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("FasterDecoding/medusa-vicuna-7b-v1.3")
# Generate with Medusa (2-3× speedup)
prompt = "Write a Python function to calculate fibonacci numbers:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.medusa_generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
posterior_threshold=0.09, # Typical-acceptance threshold (hard cap)
posterior_alpha=0.3, # Entropy scaling factor for the acceptance threshold
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
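The two `posterior_*` parameters control what the Medusa paper calls typical acceptance. The sketch below illustrates the idea on plain Python lists: a candidate token is kept when its probability under the base model exceeds an entropy-adaptive threshold. The function name and the exact threshold formula, `min(posterior_threshold, posterior_alpha * exp(-entropy))`, are assumptions drawn from the paper's description, not a copy of the library's code.

```python
import math

def typical_acceptance(probs, token, posterior_threshold=0.09, posterior_alpha=0.3):
    """Return True if `token` passes an entropy-adaptive acceptance threshold."""
    # Entropy of the base model's next-token distribution
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    # Flat (high-entropy) distributions get a lower threshold
    threshold = min(posterior_threshold, posterior_alpha * math.exp(-entropy))
    return probs[token] > threshold

confident = [0.90, 0.05, 0.03, 0.02]  # hypothetical peaked distribution
flat = [0.25, 0.25, 0.25, 0.25]       # hypothetical uniform distribution

print(typical_acceptance(confident, 1))  # unlikely token under a confident model
print(typical_acceptance(flat, 1))       # same index under an uncertain model
```

When the base model is confident, low-probability candidates are rejected. When it is uncertain, more candidates pass, which lengthens accepted runs at the cost of exactness relative to the base model's sampling distribution.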
Lookahead Decoding (Jacobi Iteration)
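import os
os.environ["USE_LADE"] = "1"  # Enable lookahead decoding; set before importing lade
import lade
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Patch transformers' generation with lookahead decoding
# (API as documented in the LookaheadDecoding README; check the repo for your installed version)
lade.augment_all()
lade.config_lade(
    LEVEL=5,           # N-gram size (N)
    WINDOW_SIZE=15,    # Lookahead window (W)
    GUESS_SET_SIZE=5,  # Number of parallel guesses
    DEBUG=0,
)
# Load model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Generate (1.5-2.3× speedup)
prompt = "Implement quicksort in Python:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))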
Core Concepts
1. Speculative Decoding (Draft Model)
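Idea: Use a small draft model to propose candidate tokens and the large target model to verify them in parallel.
Algorithm:
- Draft model generates K tokens speculatively
- Target model scores all K tokens in parallel (single forward pass)
- Accept each draft token with probability min(1, p_target / p_draft); under greedy decoding this reduces to accepting tokens where draft and target agree
- At the first rejection, discard the remaining draft tokens, sample a replacement from the target model, and continue from there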
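import torch

def speculative_decode(target_model, draft_model, input_ids, K=4):
    """One round of speculative sampling (batch size 1); returns the new token ids."""
    # 1. Draft K tokens autoregressively, keeping the draft distribution at each step
    ids, q = input_ids, []
    for _ in range(K):
        probs = torch.softmax(draft_model(ids).logits[:, -1], dim=-1)
        q.append(probs[0])
        ids = torch.cat([ids, torch.multinomial(probs, 1)], dim=-1)
    # 2. Target model scores all K drafted positions in one forward pass
    n = input_ids.shape[1]
    p = torch.softmax(target_model(ids).logits[0, n - 1:], dim=-1)  # K+1 distributions
    # 3. Accept each draft token with probability min(1, p_target / p_draft)
    accepted = []
    for i in range(K):
        tok = ids[0, n + i]
        if torch.rand(()) < torch.clamp(p[i, tok] / q[i][tok], max=1.0):
            accepted.append(tok.item())
        else:
            # Rejected: resample from the residual distribution max(0, p - q), then stop
            residual = torch.clamp(p[i] - q[i], min=0)
            accepted.append(torch.multinomial(residual / residual.sum(), 1).item())
            return accepted
    # All K drafts accepted: sample one bonus token from the target
    accepted.append(torch.multinomial(p[K], 1).item())
    return accepted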
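The accept/reject rule can be checked empirically without any model. The sketch below is a toy under stated assumptions: `target` and `draft` are made-up distributions over a three-token vocabulary, and a rejected proposal is resampled from the normalized residual max(0, p - q), the standard correction step. If the rule is implemented correctly, the observed token frequencies should match `target`, not `draft`.

```python
import random
from collections import Counter

def speculative_sample(p, q, rng):
    """Draw one token: propose from draft q, accept/reject against target p."""
    vocab = range(len(p))
    x = rng.choices(vocab, weights=q)[0]        # draft proposal
    if rng.random() < min(1.0, p[x] / q[x]):    # accept with prob min(1, p/q)
        return x
    # On rejection, resample from the residual max(0, p - q)
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0]

def empirical_distribution(p, q, n=200_000, seed=0):
    """Estimate the output distribution of speculative_sample by simulation."""
    rng = random.Random(seed)
    counts = Counter(speculative_sample(p, q, rng) for _ in range(n))
    return [counts[i] / n for i in range(len(p))]

target = [0.6, 0.3, 0.1]  # hypothetical target distribution
draft = [0.3, 0.5, 0.2]   # hypothetical, deliberately mismatched draft

freqs = empirical_distribution(target, draft)
print([round(f, 2) for f in freqs])
```

A poorly matched draft model lowers the acceptance rate, and therefore the speedup, but it does not change what is sampled.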
Performance:
- Speedup: 1.5-2× with a well-matched draft model
- Zero quality loss: the output distribution is identical to sampling from the target model alone
- Best when the draft model is 5-10× smaller than the target
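To see why draft size and draft length K trade off, a back-of-the-envelope estimate helps. The sketch below assumes each draft token is accepted independently with the same probability `alpha`, and that one draft step costs `cost_ratio` of a target forward pass. Both are simplifying assumptions, and the numbers are illustrative rather than measured.

```python
def expected_tokens_per_round(alpha, K):
    """Expected tokens emitted per target forward pass (accepted drafts + 1)."""
    if alpha == 1.0:
        return K + 1
    return (1 - alpha ** (K + 1)) / (1 - alpha)

def estimated_speedup(alpha, K, cost_ratio):
    """Tokens per round divided by the round's cost in target-pass units.

    cost_ratio: cost of one draft step relative to one target forward pass.
    """
    return expected_tokens_per_round(alpha, K) / (K * cost_ratio + 1)

# Hypothetical setting: 70% acceptance rate, draft 10x cheaper than target
for K in (2, 4, 8):
    print(K, round(estimated_speedup(alpha=0.7, K=K, cost_ratio=0.1), 2))
```

Under these assumptions the estimate peaks at a moderate K: longer drafts cost more per round while later tokens are increasingly unlikely to survive verification.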
2. Medusa (Multiple Decoding Heads)
Source: arXiv 2401.10774 (2024)
Innovation: Add multiple prediction heads to an existing model so it predicts several future tokens itself, with no separate draft model.
Architecture:
Input → Base LLM (frozen) → Hidden State
                              ├→ Head 1 (predicts token t+1)
                              ├→ Head 2 (predicts token t+2)
                              ├→ Head 3 (predicts token t+3)
                              └→ Head 4 (predicts token t+4)
Training:
- Medusa-1: Freeze base LLM, train only heads
  - 2.2× speedup, lossless
- Medusa-2: Fine-tune base LLM + heads together
  - 2.3-3.6× speedup, better quality
Tree-based Attention:
# Medusa constructs tree of candidates
# Example: Predict 2 steps ahead with top-2 per step
#          Root
#         /    \
#      T1a      T1b       (Step 1: 2 candidates)
#     /   \    /   \
#   T2a  T2b  T2c  T2d    (Step 2: 4 candidates total)
# Single forward pass evaluates entire tree!
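A single pass can evaluate the whole tree because each candidate is restricted to attend only to itself and its ancestors. A minimal sketch of building such a mask for the tree above follows; the path-tuple representation and the function name are illustrative assumptions, not Medusa's internal data structures.

```python
def build_tree_attention_mask(paths):
    """mask[i][j] is True when candidate j is candidate i or one of its ancestors."""
    index = {path: i for i, path in enumerate(paths)}
    mask = [[False] * len(paths) for _ in paths]
    for path, i in index.items():
        # Every non-empty prefix of a path is an ancestor (or the node itself)
        for depth in range(1, len(path) + 1):
            mask[i][index[path[:depth]]] = True
    return mask

# Tree from the diagram: top-2 candidates at each of 2 steps
paths = [
    ("T1a",), ("T1b",),
    ("T1a", "T2a"), ("T1a", "T2b"),
    ("T1b", "T2c"), ("T1b", "T2d"),
]

mask = build_tree_attention_mask(paths)
for path, row in zip(paths, mask):
    print(path[-1], "".join("x" if m else "." for m in row))
```

In a real model this mask would be combined with the usual causal mask over the prompt, so every candidate also sees the full prefix.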
Advantages:
- No separate draft model needed
- Minimal training (heads only, for Medusa-1)
- Compatible with any LLM
3. Lookahead Decoding (Jacobi Iteration)
Source: ICML 2024
Core idea: Reformulate autoregressive decoding as solving a system of nonlinear equations, then solve it in parallel using Jacobi iteration.
Mathematical formulation:
Traditional: y_t = f(x, y_1, ..., y_{t-1}) (sequential)
Jacobi: y_t^{(k+1)} = f(x, y_1^{(k)}, ..., y_{t-1}^{(k)}) (parallel)
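The Jacobi update can be tried on a toy problem. The sketch below assumes a hypothetical deterministic `next_token` function standing in for greedy decoding. It starts from an arbitrary guess for all m tokens, updates every position from the previous iterate, and stops at a fixed point, which coincides with the sequential result.

```python
def next_token(prefix):
    """Hypothetical deterministic 'model': greedy next token given a prefix."""
    return (sum(prefix) * 31 + len(prefix)) % 50

def sequential_decode(prompt, m):
    """Standard autoregressive decoding: one token per step."""
    seq = list(prompt)
    for _ in range(m):
        seq.append(next_token(seq))
    return seq[len(prompt):]

def jacobi_decode(prompt, m):
    """Jacobi iteration: refine all m positions at once until nothing changes."""
    y = [0] * m  # arbitrary initial guess
    for iteration in range(1, m + 1):
        # Each position depends only on the previous iterate, so this
        # comprehension could run as one batched forward pass
        new_y = [next_token(list(prompt) + y[:t]) for t in range(m)]
        if new_y == y:  # fixed point reached
            return y, iteration
        y = new_y
    return y, m

prompt = [3, 7, 11]
tokens, iterations = jacobi_decode(prompt, 8)
print(tokens == sequential_decode(prompt, 8), iterations)
```

Each iteration fixes at least one more leading token, so at most m iterations are needed. Plain Jacobi decoding rarely needs far fewer, which is the gap the n-gram lookahead and verification branches are designed to close.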
Two branches:
- Lookahead Branch: Generate n-grams in parallel
  - Window size W: How many future positions to look ahead
  - N-gram size N: How many Jacobi steps of history are combined into each n-gram
- Verification Branch: Verify promising n-grams
  - Select n-grams whose first token matches the last generated token
  - Accept the longest n-gram prefix that the model confirms
class LookaheadDecoding:
    """Conceptual sketch; generate_ngram, verify, and model.generate_next are placeholders."""

    def __init__(self, model, window_size=15, ngram_size=5):
        self.model = model
        self.W = window_size  # Lookahead window
        self.N = ngram_size   # N-gram size

    def generate_step(self, tokens):
        # Lookahead branch: generate W × N candidates
        candidates = {}
        for w in range(1, self.W + 1):
            for n in range(1, self.N + 1):
                # Generate n-gram starting at position w
                ngram = self.generate_ngram(tokens, start=w, length=n)
                candidates[(w, n)] = ngram
        # Verification branch: find matching n-grams
        verified = []
        for ngram in candidates.values():
            if ngram[0] == tokens[-1]:  # First token matches last input token
                if self.verify(tokens, ngram):
                    verified.append(ngram)
        # Accept the longest verified n-gram, else fall back to one normal decoding step
        if verified:
            return max(verified, key=len)
        return [self.model.generate_next(tokens)]
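The verification branch can be exercised without a model. The toy below assumes a hypothetical `next_token` function in place of greedy decoding and a hand-written trajectory as the source of n-grams. A real implementation checks all guesses in one batched forward pass; here they are checked one call at a time for clarity.

```python
from collections import defaultdict

def build_ngram_pool(trajectory, n):
    """Index every n-gram in `trajectory` by its first token."""
    pool = defaultdict(list)
    for i in range(len(trajectory) - n + 1):
        ngram = tuple(trajectory[i:i + n])
        pool[ngram[0]].append(ngram)
    return pool

def verify_step(tokens, pool, next_token):
    """Return the tokens accepted this step (always at least one)."""
    best = []
    for ngram in pool.get(tokens[-1], []):
        accepted, context = [], list(tokens)
        for guess in ngram[1:]:  # ngram[0] is tokens[-1] itself
            if next_token(context) != guess:
                break            # keep only the confirmed prefix
            accepted.append(guess)
            context.append(guess)
        if len(accepted) > len(best):
            best = accepted
    # No usable n-gram: fall back to one ordinary decoding step
    return best or [next_token(tokens)]

def next_token(prefix):
    """Hypothetical greedy model: counts upward modulo 5."""
    return (prefix[-1] + 1) % 5

pool = build_ngram_pool([0, 1, 2, 3, 4, 0, 1, 9], n=3)
print(verify_step([4, 0], pool, next_token))  # two pooled n-grams start with 0
print(verify_step([7], pool, next_token))     # no n-gram starts with 7
```

Because every accepted token is confirmed against the model's own greedy choice, the output matches ordinary greedy decoding; the pool only determines how many tokens are confirmed per step.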
Performance:
- Speedup: 1.5-2.3× (up to 3.6× for code generation)
- No draft model or training needed
- Works out-of-the-box with any model
Method Comparison
| Method | Speedup | Training Needed | Draft Model | Quality L