GRPO/RL Training with TRL
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
When to Use This Skill
Use GRPO training when you need to:
- Enforce specific output formats (e.g., XML tags, JSON, structured reasoning)
- Teach verifiable tasks with objective correctness metrics (math, coding, fact-checking)
- Improve reasoning capabilities by rewarding chain-of-thought patterns
- Align models to domain-specific behaviors without labeled preference data
- Optimize for multiple objectives simultaneously (format + correctness + style)
Do NOT use GRPO for:
- Simple supervised fine-tuning tasks (use SFT instead)
- Tasks without clear reward signals
- Settings where you already have high-quality preference pairs (use DPO, or PPO with a reward model trained on those pairs)
Core Concepts
1. GRPO Algorithm Fundamentals
Key Mechanism:
- Generates multiple completions for each prompt (group size: 4-16)
- Compares completions within each group using reward functions
- Updates policy to favor higher-rewarded responses relative to the group
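Critical Difference from PPO:
- No separate value (critic) model: the group's mean reward serves as the baseline
- No learned reward model required when rewards come from programmatic functions
- Learning signal comes from within-group comparisons
- Simpler to implement and debug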
Mathematical Intuition:
For each prompt p:
  1. Generate N completions: {c₁, c₂, ..., cₙ}
  2. Compute rewards: {r₁, r₂, ..., rₙ}
  3. Learn to increase probability of high-reward completions
     relative to low-reward ones in the same group
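The "relative to the group" step is usually implemented by normalizing each reward against its own group's statistics. The sketch below illustrates that arithmetic in plain Python. The function name, the `eps` value, and the use of the population standard deviation are assumptions for illustration; actual implementations differ in these details.

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group: (r - group mean) / (group std + eps).

    `eps` (an assumed value) avoids division by zero when every
    completion in the group receives the same reward.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# One prompt, four completions: two correct (reward 2.0), two wrong (0.0).
advantages = group_relative_advantages([2.0, 0.0, 0.0, 2.0])
print(advantages)  # positive for the correct completions, negative for the others

# If every completion gets the same reward, all advantages are zero,
# so that group contributes no learning signal.
print(group_relative_advantages([1.0, 1.0, 1.0]))
```

One practical consequence: a reward function that returns the same value for every completion in a group cannot teach the model anything on that prompt, which is one reason partial-credit rewards help early in training.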
2. Reward Function Design Philosophy
Golden Rules:
- Compose multiple reward functions - Each handles one aspect (format, correctness, style)
- Scale rewards deliberately - The reward with the largest magnitude dominates the within-group comparison
- Use incremental rewards - Partial credit for partial compliance
- Test rewards independently - Debug each reward function in isolation
Reward Function Types:
| Type | Use Case | Example Weight |
|---|---|---|
| Correctness | Verifiable tasks (math, code) | 2.0 (highest) |
| Format | Strict structure enforcement | 0.5-1.0 |
| Length | Encourage verbosity/conciseness | 0.1-0.5 |
| Style | Penalize unwanted patterns | -0.5 to 0.5 |
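To see how the weights in the table interact, the following sketch sums several reward functions into one total per completion. TRL performs this combination internally (and `GRPOConfig` exposes a `reward_weights` option for it), so `combine_rewards`, `has_answer_tag`, and `is_short` are hypothetical helpers used only to show the arithmetic.

```python
def combine_rewards(reward_funcs, weights, completions, **kwargs):
    """Weighted sum of several reward functions, one total per completion."""
    totals = [0.0] * len(completions)
    for func, weight in zip(reward_funcs, weights):
        scores = func(completions=completions, **kwargs)
        for i, score in enumerate(scores):
            totals[i] += weight * score
    return totals

def has_answer_tag(completions, **kwargs):
    """Format-style reward: 1.0 if an <answer> tag is present."""
    return [1.0 if "<answer>" in c[0]["content"] else 0.0 for c in completions]

def is_short(completions, **kwargs):
    """Length-style reward: 1.0 for completions under 40 characters."""
    return [1.0 if len(c[0]["content"]) < 40 else 0.0 for c in completions]

completions = [
    [{"role": "assistant", "content": "<answer>4</answer>"}],
    [{"role": "assistant", "content": "The answer is probably four, but I am not sure about it."}],
]
totals = combine_rewards([has_answer_tag, is_short], [1.0, 0.25], completions)
print(totals)  # [1.25, 0.0]
```

Keeping the length weight well below the format weight means a short but malformed completion can never outscore a well-formed one.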
Implementation Workflow
Step 1: Dataset Preparation
Critical Requirements:
- Prompts in chat format (list of dicts with 'role' and 'content')
- Include system prompts to set expectations
- For verifiable tasks, include ground truth answers as additional columns
Example Structure:
from datasets import load_dataset, Dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""

def prepare_dataset(raw_data: Dataset) -> Dataset:
    """
    Transform raw data into GRPO-compatible format.

    Returns: Dataset with columns:
    - 'prompt': List[Dict] with role/content (system + user messages)
    - 'answer': str (ground truth, optional but recommended)
    """
    # extract_answer is a user-supplied helper (not shown here) that pulls
    # the ground-truth answer out of the raw answer text.
    return raw_data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_answer(x['raw_answer'])
    })
Pro Tips:
- Use one-shot or few-shot examples in system prompt for complex formats
- Keep prompts concise (max_prompt_length: 256-512 tokens)
- Validate data quality before training (garbage in = garbage out)
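A quick structural check catches most formatting mistakes before any GPU time is spent. The sketch below is one possible validator, not part of TRL; `validate_example` and its messages are hypothetical, and it treats the ground-truth `answer` column as required even though it is only recommended.

```python
REQUIRED_KEYS = {"role", "content"}

def validate_example(example):
    """Return a list of problems found in one dataset row (empty if OK)."""
    problems = []
    prompt = example.get("prompt")
    if not isinstance(prompt, list) or not prompt:
        return ["'prompt' must be a non-empty list of messages"]
    for i, msg in enumerate(prompt):
        if not isinstance(msg, dict) or not REQUIRED_KEYS <= set(msg):
            problems.append(f"message {i} needs 'role' and 'content' keys")
        elif not str(msg["content"]).strip():
            problems.append(f"message {i} has empty content")
    answer = example.get("answer")
    if answer is None or not str(answer).strip():
        problems.append("missing ground-truth 'answer'")
    return problems

good = {
    "prompt": [
        {"role": "system", "content": "Respond in the required format."},
        {"role": "user", "content": "What is 2 + 2?"},
    ],
    "answer": "4",
}
bad = {"prompt": [{"role": "user"}], "answer": ""}

print(validate_example(good))  # []
print(validate_example(bad))   # two problems: malformed message, missing answer
```

Running this over a sample of rows and printing any non-empty results is usually enough to spot a broken mapping function.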
Step 2: Reward Function Implementation
Template Structure:
def reward_function_name(
    prompts,        # List[List[Dict]]: Original prompts
    completions,    # List[List[Dict]]: Model generations
    answer=None,    # Optional: Ground truth from dataset
    **kwargs        # Additional dataset columns
) -> list[float]:
    """
    Evaluate completions and return rewards.

    Returns: List of floats (one per completion)
    """
    # Extract completion text
    responses = [comp[0]['content'] for comp in completions]

    # Compute rewards (compute_score is a placeholder for your scoring logic)
    rewards = []
    for response in responses:
        score = compute_score(response)
        rewards.append(score)
    return rewards
Example 1: Correctness Reward (Math/Coding)
def correctness_reward(prompts, completions, answer, **kwargs):
    """Reward correct answers with high score."""
    responses = [comp[0]['content'] for comp in completions]
    # extract_final_answer is a user-supplied helper (not shown here) that
    # pulls the final answer out of the completion text.
    extracted = [extract_final_answer(r) for r in responses]
    return [2.0 if ans == gt else 0.0
            for ans, gt in zip(extracted, answer)]
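The correctness reward above depends on an `extract_final_answer` helper that is not defined in this guide. A minimal sketch follows, assuming the `<answer>` tag format from the system prompt and string ground truths. Taking the last tag pair and stripping whitespace are illustrative choices, not requirements.

```python
import re

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

def extract_final_answer(text):
    """Return the text inside the last <answer>...</answer> pair, or None."""
    matches = ANSWER_RE.findall(text)
    return matches[-1].strip() if matches else None

def correctness_reward(prompts, completions, answer, **kwargs):
    """2.0 for an exact match with the ground truth, 0.0 otherwise."""
    responses = [comp[0]["content"] for comp in completions]
    extracted = [extract_final_answer(r) for r in responses]
    return [2.0 if ans is not None and ans == str(gt).strip() else 0.0
            for ans, gt in zip(extracted, answer)]

completions = [
    [{"role": "assistant", "content": "<reasoning>2+2</reasoning>\n<answer> 4 </answer>"}],
    [{"role": "assistant", "content": "It is 4."}],
]
print(correctness_reward(None, completions, answer=["4", "4"]))  # [2.0, 0.0]
```

Exact string comparison is strict: "4.0" and "4" would not match. For numeric tasks, normalizing both sides before comparing is a common refinement.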
Example 2: Format Reward (Structured Output)
import re

def format_reward(completions, **kwargs):
    """Reward XML-like structured format."""
    pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
    responses = [comp[0]['content'] for comp in completions]
    return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
            for r in responses]
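Reward functions are easy to test in isolation with hand-written completions. The sketch below runs the format reward on three mock cases and contrasts it with a stricter variant. `as_completion` and `strict_format_reward` are hypothetical names introduced here. Because `re.search` accepts any text before or after the tags, the strict variant uses `fullmatch` instead.

```python
import re

PATTERN = re.compile(
    r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL
)

def as_completion(text):
    """Wrap raw text in the chat-style structure reward functions receive."""
    return [{"role": "assistant", "content": text}]

def format_reward(completions, **kwargs):
    """1.0 if the tag structure appears anywhere in the response."""
    responses = [comp[0]["content"] for comp in completions]
    return [1.0 if PATTERN.search(r) else 0.0 for r in responses]

def strict_format_reward(completions, **kwargs):
    """1.0 only if the whole response is exactly the tag structure."""
    responses = [comp[0]["content"] for comp in completions]
    return [1.0 if PATTERN.fullmatch(r.strip()) else 0.0 for r in responses]

cases = [
    as_completion("<reasoning>2+2=4</reasoning>\n<answer>4</answer>"),
    as_completion("Sure! <reasoning>2+2=4</reasoning><answer>4</answer> Hope that helps."),
    as_completion("The answer is 4."),
]
print(format_reward(cases))         # [1.0, 1.0, 0.0]
print(strict_format_reward(cases))  # [1.0, 0.0, 0.0]
```

Whether the looser or stricter check is appropriate depends on how much surrounding chatter the downstream parser can tolerate.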
Example 3: Incremental Format Reward (Partial Credit)
def incremental_format_reward(completions, **kwargs):
    """Award partial credit for format compliance."""
    responses = [comp[0]['content'] for comp in completions]
    rewards = []
    for r in responses:
        score = 0.0
        if '<reasoning>' in r:
            score += 0.25
        if '</reasoning>' in r:
            score += 0.25
        if '<answer>' in r:
            score += 0.25
        if '</answer>' in r:
            score += 0.25
        # Penalize extra text after closing tag
        if r.count('</answer>') == 1:
            extra_text = r.split('</answer>')[-1].strip()
            score -= len(extra_text) * 0.001
        rewards.append(score)
    return rewards
Critical Insight: Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
Step 3: Training Configuration
Memory-Optimized Config (Small GPU)
from trl import GRPOConfig
training_args = GRPOConfig(
    output_dir="outputs/grpo-model",
    # Learning rate
    learning_rate=5e-6,               # Lower = more stable
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    # Batch settings
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,    # Effective batch = 4
    # GRPO-specific
    # Note: recent TRL versions require the effective batch size to be
    # divisible by num_generations; raise gradient_accumulation_steps
    # (e.g., to 8) if TRL reports a divisibility error.
    num_generations=8,                # Group size: 8-16 recommended
    max_prompt_length=256,
    max_completion_length=512,
    # Training duration
    num_train_epochs=1,
    max_steps=-1,                     # -1 = use num_train_epochs; or set fixed steps (e.g., 500)
    # Optimization
    bf16=True,                        # Faster on A100/H100
    optim="adamw_8bit",               # Memory-efficient optimizer (requires bitsandbytes)
    max_grad_norm=0.1,
    # Logging
    logging_steps=1,
    save_steps=100,
    report_to="wandb",                # Or "none" for no logging
)
High-Performance Config (Large GPU)
training_args = GRPOConfig(
    output_dir="outputs/grpo-model",
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_generations=16,               # Larger groups = better signal
    max_prompt_length=512,
    max_completion_length=1024,
    num_train_epochs=1,
    bf16=True,
    use_vllm=True,                    # Fast generation with vLLM (requires the vllm package)
    logging_steps=10,
)
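Batch settings and `num_generations` interact. Recent TRL documentation for `num_generations` states that the effective batch size (processes × per-device batch size × gradient accumulation steps) must be evenly divisible by it. The rule has changed between TRL releases, so treat the formula below as an assumption and confirm it against the installed version. `check_grpo_batch` is a hypothetical helper, not a TRL function.

```python
def check_grpo_batch(per_device_train_batch_size, gradient_accumulation_steps,
                     num_generations, num_processes=1):
    """Return (effective_batch_size, is_divisible_by_num_generations).

    Assumes: effective batch = processes * per-device batch * accumulation steps.
    """
    effective = (per_device_train_batch_size
                 * gradient_accumulation_steps
                 * num_processes)
    return effective, effective % num_generations == 0

# Memory-optimized config above, single GPU.
print(check_grpo_batch(1, 4, num_generations=8))                    # (4, False)
# Same config with gradient_accumulation_steps raised to 8.
print(check_grpo_batch(1, 8, num_generations=8))                    # (8, True)
# High-performance config above on two GPUs.
print(check_grpo_batch(4, 2, num_generations=16, num_processes=2))  # (16, True)
```

Under this assumption, both example configs need either more accumulation steps or more processes before they satisfy the constraint on a single GPU.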
Critical Hyperparameters:
| Parameter | Impact | Tuning Advice |
|---|---|---|
| num_generations | Group size for comparison | Start with 8, increase to 16 if GPU allows |
| learning_rate | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
| max_completion_length | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
| gradient_accumulation_steps | Effective batch size | Increase if GPU memory is limited |