Long Context: Extending Transformer Context Windows
When to Use This Skill
Use Long Context techniques when you need to:
- Process long documents (32k, 64k, 128k+ tokens) with transformer models
- Extend context windows of pre-trained models (LLaMA, Mistral, etc.)
- Implement efficient positional encodings (RoPE, ALiBi)
- Train models with length extrapolation capabilities
- Deploy models that handle variable-length inputs efficiently
- Fine-tune existing models for longer contexts with minimal compute
Key Techniques: RoPE (Rotary Position Embeddings), YaRN, ALiBi (Attention with Linear Biases), Position Interpolation
Papers: RoFormer (arXiv 2104.09864), YaRN (arXiv 2309.00071), ALiBi (arXiv 2108.12409), Position Interpolation (arXiv 2306.15595)
Installation
# HuggingFace Transformers (includes RoPE, YaRN support)
pip install transformers torch
# For custom implementations
pip install einops # Tensor operations
pip install rotary-embedding-torch # Standalone RoPE
# Optional: FlashAttention for efficiency
pip install flash-attn --no-build-isolation
Quick Start
RoPE (Rotary Position Embeddings)
import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE)."""

    def __init__(self, dim, max_seq_len=8192, base=10000):
        super().__init__()
        # Inverse frequencies, one per pair of channels: shape (dim/2,)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len, device):
        # Keep frequencies and positions on the same device and dtype
        inv_freq = self.inv_freq.to(device)
        t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        # Rotation angle for every (position, frequency) pair
        freqs = torch.outer(t, inv_freq)  # (seq_len, dim/2)
        # Duplicate so the angles line up with both halves used by rotate_half
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        return emb.cos(), emb.sin()

def rotate_half(x):
    """Rotate half the hidden dimensions."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embeddings to queries and keys."""
    # q, k shape: (batch, heads, seq_len, dim); cos, sin broadcast over batch and heads
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# Usage
device = "cuda" if torch.cuda.is_available() else "cpu"
rope = RotaryEmbedding(dim=64, max_seq_len=8192)
cos, sin = rope(seq_len=2048, device=device)

# In attention layer (random tensors stand in for the projected queries and keys)
query = torch.randn(1, 8, 2048, 64, device=device)
key = torch.randn(1, 8, 2048, 64, device=device)
q_rotated, k_rotated = apply_rotary_pos_emb(query, key, cos, sin)
ALiBi (Attention with Linear Biases)
import math
import torch

def get_alibi_slopes(num_heads):
    """Get ALiBi slope values for each attention head."""
    def get_slopes_power_of_2(n):
        # Geometric sequence that starts at 2^(-8/n) and uses the same ratio
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return get_slopes_power_of_2(num_heads)
    else:
        # Closest lower power of 2
        closest_power = 2 ** math.floor(math.log2(num_heads))
        slopes = get_slopes_power_of_2(closest_power)
        # Fill the remaining heads with every other slope of the next power of 2
        extra = get_slopes_power_of_2(2 * closest_power)
        slopes.extend(extra[0::2][:num_heads - closest_power])
        return slopes

def create_alibi_bias(seq_len, num_heads):
    """Create ALiBi attention bias."""
    # Distance matrix: entry [i, j] is |i - j|
    context_position = torch.arange(seq_len)
    memory_position = torch.arange(seq_len)
    distance = (memory_position[None, :] - context_position[:, None]).abs()
    # Get slopes
    slopes = torch.tensor(get_alibi_slopes(num_heads))
    # Bias is -slope * distance, so more distant keys are penalized more
    alibi = -slopes[:, None, None] * distance[None, :, :]
    return alibi  # (num_heads, seq_len, seq_len)

# Usage in attention
num_heads = 8
seq_len = 2048
# attn_scores shape: (batch, num_heads, seq_len, seq_len)
attn_scores = torch.randn(1, num_heads, seq_len, seq_len)  # placeholder for q @ k^T / sqrt(d)
alibi_bias = create_alibi_bias(seq_len, num_heads).to(attn_scores.device)
# Add bias to attention scores (apply the causal mask as usual before the softmax)
attn_scores = attn_scores + alibi_bias
attn_weights = torch.softmax(attn_scores, dim=-1)
Position Interpolation for LLaMA
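from transformers import LlamaForCausalLM

# Llama 2 original context: 4096 tokens
# Extend to 32k with position interpolation (linear RoPE scaling).
# This scales the position indices; the RoPE base frequency is unchanged.
# Pass rope_scaling at load time so the rotary embedding is built with it.
# Newer transformers releases may expect the key "rope_type" in place of "type".
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    rope_scaling={
        "type": "linear",
        "factor": 8.0,  # 4096 * 8 = 32768
    },
)

# Or use dynamic NTK scaling, which adjusts the RoPE base as the input grows
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    rope_scaling={
        "type": "dynamic",
        "factor": 8.0,
    },
)

# Linear interpolation needs a short fine-tune on long documents
# (on the order of 1000 steps) before quality recovers.
# Dynamic scaling can be tried without fine-tuning, usually at some cost in quality.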
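The "dynamic" option can be read as raising the RoPE base once the input outgrows the trained context. The sketch below assumes the rule used by the dynamic NTK scaling in Transformers; the helper name `dynamic_ntk_base` and the example sizes are illustrative, not library API.

```python
def dynamic_ntk_base(base, dim, seq_len, max_position_embeddings, factor):
    """Return the adjusted RoPE base for dynamic NTK scaling (illustrative helper)."""
    # Inside the trained context the base is left unchanged
    seq_len = max(seq_len, max_position_embeddings)
    scale = (factor * seq_len / max_position_embeddings) - (factor - 1)
    return base * scale ** (dim / (dim - 2))

# Assumed sizes: head dim 128, trained context 4096, scaling factor 8
base_short = dynamic_ntk_base(10000.0, 128, 2048, 4096, 8.0)
base_long = dynamic_ntk_base(10000.0, 128, 32768, 4096, 8.0)
print(base_short)              # 10000.0, no change within the trained context
print(base_long > base_short)  # True, longer inputs get a larger base
```

A larger base lowers every rotation frequency, so long inputs are stretched without touching behavior on short ones.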
Core Concepts
1. RoPE (Rotary Position Embeddings)
How it works:
- Encodes absolute position via rotation matrix
- Provides relative position dependency in attention
- Places no hard limit on sequence length, though quality degrades beyond the trained length unless the positions are rescaled
Mathematical formulation:
q_m = (W_q * x_m) * e^(imθ)
k_n = (W_k * x_n) * e^(inθ)
where θ_j = base^(-2j/d) for j ∈ [0, d/2)
Advantages:
- Decaying inter-token dependency with distance
- Compatible with linear attention
- Better extrapolation than absolute position encodings
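The relative-position property can be checked numerically with the complex form of the formulation above. This is a minimal sketch using only the standard library; the helper names `rope_rotate` and `score` and the toy vectors are assumptions made for illustration.

```python
import cmath

def rope_rotate(pairs, position, base=10000.0):
    """Rotate complex channel pairs by position * theta_j (illustrative helper)."""
    d = 2 * len(pairs)  # each complex number stands for two real channels
    return [
        z * cmath.exp(1j * position * base ** (-2 * j / d))
        for j, z in enumerate(pairs)
    ]

def score(q, k):
    """Real inner product of two complex vectors: Re(sum of q * conj(k))."""
    return sum((a * b.conjugate()).real for a, b in zip(q, k))

q = [1.0 + 2.0j, -0.5 + 0.3j]
k = [0.7 - 1.1j, 2.0 + 0.4j]

# Same offset (m - n = 3) at two different absolute positions
s_near = score(rope_rotate(q, 5), rope_rotate(k, 2))
s_far = score(rope_rotate(q, 105), rope_rotate(k, 102))
# A different offset (m - n = 4)
s_other = score(rope_rotate(q, 5), rope_rotate(k, 1))

print(abs(s_near - s_far) < 1e-9)    # True: only the offset matters
print(abs(s_near - s_other) > 1e-6)  # True: a different offset changes the score
```

Rotation also leaves vector norms unchanged, which is why RoPE can be applied to queries and keys without rescaling them.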
2. YaRN (Yet another RoPE extensioN)
Key innovation:
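- NTK-by-parts interpolation: high-frequency dimensions keep their original frequencies, low-frequency dimensions are interpolated, with a linear ramp in between
- Attention temperature scaling
- Efficient context extension (10× fewer tokens than previous methods)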
Parameters:
# YaRN configuration
yarn_config = {
    "scale": 16,                    # Extension factor
    "original_max_position": 2048,  # Base context
    "extrapolation_factor": 1.0,    # Weight on the unscaled (extrapolated) frequencies
    "attn_factor": 1.0,             # Multiplier on the attention scaling
    "beta_fast": 32,                # Rotations above which a dimension is left unscaled
    "beta_slow": 1,                 # Rotations below which a dimension is fully interpolated
}
Performance:
- Extends LLaMA to 128k tokens
- 2.5× fewer training steps than previous methods
- State-of-the-art context window extension at the time of publication
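How the parameters above interact can be sketched in plain Python. This follows the NTK-by-parts blend and the attention scaling 0.1 · ln(scale) + 1 described in the YaRN paper, as far as they can be reconstructed here; the function names and the head dimension of 128 are assumptions, and `extrapolation_factor` and `attn_factor` are taken as 1.0.

```python
import math

def yarn_inv_freq(dim, base, scale, original_max_position, beta_fast=32, beta_slow=1):
    """Blend original and interpolated RoPE frequencies per dimension (sketch)."""
    def correction_dim(num_rotations):
        # Dimension index that completes num_rotations over the original context
        return (dim * math.log(original_max_position / (num_rotations * 2 * math.pi))
                / (2 * math.log(base)))

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
    if low == high:
        high += 0.001  # avoid division by zero in the ramp

    inv_freq = []
    for j in range(dim // 2):
        original = base ** (-2 * j / dim)  # unscaled RoPE frequency
        interpolated = original / scale    # Position Interpolation frequency
        ramp = min(max((j - low) / (high - low), 0.0), 1.0)
        # ramp = 0 keeps the original frequency, ramp = 1 fully interpolates
        inv_freq.append(original * (1 - ramp) + interpolated * ramp)
    return inv_freq

def yarn_attention_factor(scale):
    """Attention scaling as a function of the extension factor."""
    return 0.1 * math.log(scale) + 1.0

freqs = yarn_inv_freq(dim=128, base=10000.0, scale=16, original_max_position=2048)
print(freqs[0])                             # 1.0, highest frequency is unchanged
print(round(yarn_attention_factor(16), 4))  # 1.2773
```

The lowest-frequency dimension ends up divided by the full scale, which matches plain Position Interpolation for that dimension.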
3. ALiBi (Attention with Linear Biases)
Core idea:
- No positional embeddings added to tokens
- Apply distance penalty directly to attention scores
- Bias proportional to key-query distance
Formula:
attention_bias[i, j] = -m * |i - j|
where m = slope for each attention head
Advantages:
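- A model trained on 1k-token inputs matches a sinusoidal model trained on 2k-token inputs when both are evaluated at 2k, while training 11% faster
- 11% less memory usage in the same comparison
- Strong length extrapolation (train 1k, test 2k+)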
- Inductive bias towards recency
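A small worked example makes the recency bias concrete. This sketch uses only the standard library; the helper names and the four-token toy row are assumptions for illustration, and the slope rule covers only head counts that are powers of two.

```python
import math

def alibi_slopes_power_of_2(num_heads):
    """Slopes 2^(-8/n), 2^(-16/n), ... for a power-of-two head count."""
    start = 2 ** (-8 / num_heads)
    return [start ** (i + 1) for i in range(num_heads)]

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

slopes = alibi_slopes_power_of_2(8)
print(slopes[0], slopes[-1])  # 0.5 0.00390625

# Query at position i = 3 attending to keys j = 0..3 with identical raw scores
i, raw_scores = 3, [0.0, 0.0, 0.0, 0.0]
biased = [s - slopes[0] * abs(i - j) for j, s in enumerate(raw_scores)]
weights = softmax(biased)
print(biased)  # [-1.5, -1.0, -0.5, 0.0]
```

With equal raw scores the attention weights grow toward the most recent key. Heads with smaller slopes apply a gentler penalty and so keep a longer effective range.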
4. Position Interpolation
Technique:
- Linearly down-scale position indices
- Interpolate within trained range (vs extrapolate beyond)
- Minimal fine-tuning required
Formula:
# Original: position indices [0, 1, 2, ..., L-1]
# Extended (2× extension): indices [0, 1, 2, ..., 2L-1] map to [0, 0.5, 1.0, ..., L-0.5]
scaled_position[i] = i / extension_factor
Results:
- LLaMA 7B-65B extended to up to 32k tokens
- About 1000 fine-tuning steps are sufficient
- Upper bound on the interpolated attention score is roughly 600× smaller than for extrapolation
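The scaling rule is easy to verify on concrete numbers. The sketch below assumes a model trained on 2048 positions and extended to 4096; `interpolate_positions` is an illustrative helper, not a library function.

```python
def interpolate_positions(seq_len, extension_factor):
    """Down-scale integer positions so they stay inside the trained range."""
    return [i / extension_factor for i in range(seq_len)]

trained_len = 2048                 # assumed original context length
target_len = 4096                  # desired context length
factor = target_len / trained_len  # 2.0

scaled = interpolate_positions(target_len, factor)
print(scaled[:4])   # [0.0, 0.5, 1.0, 1.5]
print(max(scaled))  # 2047.5, still inside the trained range
```

Every scaled position falls between two positions the model has already seen, which is why a short fine-tune is enough to adapt to the denser spacing.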
Method Comparison
| Method | Max Context | Training Needed | Memory | Extrapolation | Best For |
|---|---|---|---|---|---|
| RoPE | 8k-32k | Full pre-training | Moderate | Limited without scaling | New models |
| YaRN | 32k-128k | Minimal (10× fewer tokens) | Moderate | Excellent | Extending existing models |
| ALiBi | No hard limit | Full pre-training | Low (-11%) | Excellent | Training from scratch |
| Position Interpolation | 32k+ | Minimal (1k steps) | Moderate | Poor (by design) | Quick extension |
Implementation Patterns
HuggingFace Transformers Integration
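from transformers import AutoModelForCausalLM, AutoConfig

# RoPE with YaRN scaling
# (whether "yarn" is accepted depends on the model class and transformers version)
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
config.rope_scaling = {
    "type": "yarn",
    "factor": 8.0,
    "original_max_position_embeddings": 8192,
    "attention_factor": 1.0
}
# from_pretrained loads the trained weights; from_config would initialize them randomly
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", config=config)

# Position interpolation (simpler)
config.rope_scaling = {
    "type": "linear",
    "factor": 4.0
}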
# Dynamic scaling (a