OpenPI Fine-Tuning and Serving
End-to-end workflows for fine-tuning and serving Physical Intelligence's OpenPI models (pi0, pi0-fast, pi0.5) on robot manipulation tasks, using the public openpi repository. Covers blank-machine setup, JAX training, PyTorch training, checkpoint conversion, and policy inference serving.
Quick start
Clone the public repo, install the workspace, then serve a pretrained policy:
git clone --recurse-submodules https://github.com/Physical-Intelligence/openpi.git
cd openpi
GIT_LFS_SKIP_SMUDGE=1 uv sync
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
uv run scripts/serve_policy.py --env DROID
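Then query the server from Python. observation is a dict whose keys (images, state, and a "prompt" string) match the served config's inputs:
from openpi_client import websocket_client_policy
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
result = client.infer(observation)  # observation: dict built for the served config
actions = result["actions"]  # numpy array of shape (chunk_size, action_dim)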
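Each infer() call returns a whole chunk of future actions rather than a single action. Below is a minimal sketch of a control loop that executes part of each chunk and then re-queries the policy. The names get_observation, apply_action, replan_every and FakePolicy are hypothetical stand-ins for your robot interface. How many actions to run per chunk is a tuning choice, not an openpi default.

```python
import numpy as np


def run_chunked_control(policy, get_observation, apply_action, num_steps: int, replan_every: int) -> int:
    """Execute actions from each chunk, re-querying the policy every replan_every steps.

    Returns the number of infer() calls made.
    """
    chunk, idx, queries = None, 0, 0
    for _ in range(num_steps):
        # Re-query when there is no chunk yet, replan_every actions were used, or the chunk ran out
        if chunk is None or idx >= min(replan_every, len(chunk)):
            chunk = policy.infer(get_observation())["actions"]
            idx, queries = 0, queries + 1
        apply_action(chunk[idx])
        idx += 1
    return queries


class FakePolicy:
    """Hypothetical stand-in for WebsocketClientPolicy that returns a zero (10, 7) chunk."""

    def infer(self, obs):
        return {"actions": np.zeros((10, 7))}


applied = []
n_queries = run_chunked_control(FakePolicy(), lambda: {}, applied.append, num_steps=25, replan_every=5)
print(n_queries, len(applied))  # 5 25
```

Executing fewer actions per chunk makes the robot more reactive, but it costs more round trips to the server.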
Core concepts
Model family: OpenPI implements three model variants from Physical Intelligence:
| Model | Architecture | Speed | Quality | Typical use |
|---|---|---|---|---|
| pi0 | Flow-matching VLA | Baseline | High | Research, complex tasks |
| pi0-fast | Autoregressive action tokens | 2-5x faster | Good | Real-time control |
| pi0.5 | pi0 + improved vision encoder | Baseline | Highest | Latest default |
Key design choices:
- Dual backend: JAX (primary, official training) and PyTorch (community, deployment-friendly)
- Config-driven: All training/serving parameters are defined in src/openpi/training/config.py
- Norm stats: Every config requires precomputed normalization statistics before training
- WebSocket serving: Policy servers expose a WebSocket API for low-latency inference
Training loop invariant: After every config or dataset change, always re-run this cycle:
1. Compute norm stats → 2. Train → 3. Serve checkpoint → 4. Validate inference
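The cycle can be scripted so that no step is skipped. This is a minimal sketch that reuses the script paths and flags from the workflows in this document. training_cycle_commands, run_cycle and the checkpoint step argument are hypothetical names. With dry_run=True (the default) it only returns the commands. Step 4, validation, still happens from a client.

```python
import os
import subprocess


def training_cycle_commands(config_name: str, exp_name: str, step: int) -> list[list[str]]:
    """Return the norm-stats, train, and serve commands for one cycle, in order."""
    return [
        ["uv", "run", "scripts/compute_norm_stats.py", "--config-name", config_name],
        ["uv", "run", "scripts/train.py", config_name, f"--exp-name={exp_name}", "--overwrite"],
        [
            "uv", "run", "scripts/serve_policy.py", "policy:checkpoint",
            f"--policy.config={config_name}",
            f"--policy.dir=checkpoints/{config_name}/{exp_name}/{step}",
        ],
    ]


def run_cycle(config_name: str, exp_name: str, step: int, dry_run: bool = True) -> list[list[str]]:
    commands = training_cycle_commands(config_name, exp_name, step)
    if dry_run:
        return commands
    # Memory fraction used by the JAX training command in Workflow 1
    env = {**os.environ, "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.9"}
    for cmd in commands[:-1]:
        # check=True aborts the cycle if norm stats or training fails
        subprocess.run(cmd, env=env, check=True)
    # The policy server runs until stopped, so start it in the background
    subprocess.Popen(commands[-1], env=env)
    return commands
```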
Compute requirements
| Task | GPU | VRAM | Notes |
|---|---|---|---|
| Serve pi0.5 (inference) | 1x A100/H100 | ~24 GB | Single GPU sufficient |
| Fine-tune pi0.5 (JAX) | 1x A100 80GB | ~60 GB | Use fsdp_devices for multi-GPU |
| Fine-tune pi0 (JAX) | 1x A100 80GB | ~40 GB | Smaller model footprint |
| Fine-tune (PyTorch DDP) | 1-8x A100 | ~40 GB/GPU | torchrun launcher |
| Compute norm stats | CPU or 1x GPU | ~8 GB | Fast, can run on login node |
Workflow 0: Blank-machine setup
Copy this checklist and track progress:
Setup Progress:
- [ ] Step 1: Clone the public openpi repo with submodules
- [ ] Step 2: Install uv and sync the workspace
- [ ] Step 3: Install the editable package
- [ ] Step 4: Verify core imports and serving entrypoint
Step 1: Clone repo
git clone --recurse-submodules https://github.com/Physical-Intelligence/openpi.git
cd openpi
If you already cloned without submodules:
git submodule update --init --recursive
Step 2: Sync dependencies
GIT_LFS_SKIP_SMUDGE=1 uv sync
Step 3: Install editable package
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
Step 4: Verify installation
uv run python -c "from openpi.training import config as _config; print(_config.get_config('pi05_droid').name)"
uv run scripts/serve_policy.py --help
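A scripted version of the import check can be handy when setting up several machines. This is a minimal standard-library sketch. The module list is an assumption drawn from the imports used elsewhere in this guide. Run it with uv run python so it sees the project environment.

```python
import importlib.util


def check_modules(names: list[str]) -> dict[str, bool]:
    """Map each module name to whether it can be found in the current environment."""
    status = {}
    for name in names:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # find_spec imports the parent package of a dotted name; a missing parent lands here
            status[name] = False
    return status


if __name__ == "__main__":
    for name, ok in check_modules(["openpi", "openpi_client", "jax"]).items():
        print(f"{name}: {'ok' if ok else 'MISSING'}")
```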
When to use vs alternatives
Use this skill when:
- Fine-tuning pi0, pi0-fast, or pi0.5 on LeRobot or RLDS datasets
- Serving OpenPI policies for ALOHA, DROID, or LIBERO evaluation
- Converting JAX checkpoints to PyTorch format
- Debugging OpenPI training issues (norm stats, memory, config)
Use fine-tuning-openvla-oft instead when:
- Fine-tuning OpenVLA with continuous action heads and LoRA
- Reproducing OpenVLA-OFT paper results on LIBERO or ALOHA
Use evaluating-cosmos-policy instead when:
- Evaluating NVIDIA Cosmos Policy on simulation benchmarks
Workflow 1: JAX fine-tuning on LeRobot data
Copy this checklist and track progress:
JAX Fine-Tuning Progress:
- [ ] Step 1: Select and copy closest training config
- [ ] Step 2: Update dataset mapping and base checkpoint
- [ ] Step 3: Compute normalization statistics
- [ ] Step 4: Launch JAX training
- [ ] Step 5: Serve checkpoint and run inference sanity check
Step 1: Select config
Copy the closest config from src/openpi/training/config.py:
| Config | Use case |
|---|---|
| pi05_libero | pi0.5 LIBERO fine-tuning |
| pi0_libero | pi0 full fine-tuning on LIBERO |
| pi0_fast_libero | pi0-fast on LIBERO |
| pi0_aloha_pen_uncap | ALOHA custom data |
| pi05_droid_finetune | Small custom DROID dataset (LeRobot format) |
| pi05_full_droid_finetune | Full DROID RLDS large-scale training |
Step 2: Update dataset and transforms
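# In src/openpi/training/config.py, add a copy of the closest config to the config list.
# Class names follow the pi05_libero entry; keep whatever the config you copied uses.
TrainConfig(
    name="my_custom_config",
    model=pi0_config.Pi0Config(pi05=True),  # pi0.5; drop pi05=True for pi0
    data=LeRobotLiberoDataConfig(
        repo_id="your-org/your-dataset",
        base_config=DataConfig(prompt_from_task=True),
        # Adjust transforms to match your data format
    ),
    # Base checkpoint must match the model type (pi05_base for pi0.5, pi0_base for pi0)
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi05_base/params"
    ),
)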
Set repo_id for your dataset. Make sure the model config and the base checkpoint loaded by weight_loader both match the model type (pi0 vs pi0.5).
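Adjusting transforms mostly means mapping your dataset's keys onto the keys that the policy's input transforms expect. The sketch below shows that remapping in plain Python. It is not openpi's own transform code, and the record and key names in the example are hypothetical.

```python
from typing import Any


def get_path(sample: dict[str, Any], path: str) -> Any:
    """Follow a '/'-separated path into a nested dict, e.g. 'cam/front'."""
    value: Any = sample
    for part in path.split("/"):
        value = value[part]
    return value


def repack(sample: dict[str, Any], mapping: dict[str, str]) -> dict[str, Any]:
    """Build a new sample keyed by mapping's keys, reading each value from its source path."""
    return {new_key: get_path(sample, src_path) for new_key, src_path in mapping.items()}


# Hypothetical dataset record and target layout
record = {"cam": {"front": "front_img"}, "joints": [0.1, 0.2], "task": "pick up the cup"}
mapping = {"observation/image": "cam/front", "observation/state": "joints", "prompt": "task"}
print(repack(record, mapping))
# {'observation/image': 'front_img', 'observation/state': [0.1, 0.2], 'prompt': 'pick up the cup'}
```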
Step 3: Compute normalization statistics
uv run scripts/compute_norm_stats.py --config-name <config_name>
Re-run this before launching training whenever the config, dataset, or transforms change; stale statistics will not match the new data.
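Normalization statistics are per-dimension summaries of the state and action arrays: mean and standard deviation, plus low and high percentiles for quantile-based scaling. The sketch below is an illustration of the idea in numpy, not openpi's compute_norm_stats.py. The 1st/99th percentile choice, the eps value and the synthetic data are assumptions. Check your generated stats file for the fields your config actually uses.

```python
import numpy as np


def compute_stats(x: np.ndarray) -> dict:
    """Per-dimension statistics over a (num_frames, dim) array of states or actions."""
    return {
        "mean": x.mean(axis=0),
        "std": x.std(axis=0),
        "q01": np.quantile(x, 0.01, axis=0),
        "q99": np.quantile(x, 0.99, axis=0),
    }


def normalize_zscore(x: np.ndarray, stats: dict, eps: float = 1e-6) -> np.ndarray:
    """Zero mean, unit variance per dimension; eps guards constant dimensions."""
    return (x - stats["mean"]) / (stats["std"] + eps)


def normalize_quantile(x: np.ndarray, stats: dict, eps: float = 1e-6) -> np.ndarray:
    """Map each dimension's [q01, q99] range onto [-1, 1]; values outside are not clipped."""
    return 2.0 * (x - stats["q01"]) / (stats["q99"] - stats["q01"] + eps) - 1.0


# Synthetic 7-D actions standing in for a real dataset
rng = np.random.default_rng(0)
actions = rng.normal(loc=2.0, scale=3.0, size=(10_000, 7))
stats = compute_stats(actions)
z = normalize_zscore(actions, stats)
q = normalize_quantile(actions, stats)
```

Statistics computed on an old dataset shift every normalized value. That mismatch is why the stats must be recomputed whenever the data or transforms change.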
Step 4: Launch JAX training
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py <config_name> \
--exp-name=<run_name> \
--overwrite
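XLA_PYTHON_CLIENT_MEM_FRACTION sets the fraction of total GPU memory that JAX preallocates at startup; JAX's default is 0.75. The arithmetic on an 80 GB A100 is worked below. preallocated_gb is a hypothetical helper. Raising the fraction leaves more room for model and optimizer state, but less for other processes on the same GPU.

```python
def preallocated_gb(total_gb: float, fraction: float = 0.75) -> float:
    """GPU memory (GB) JAX preallocates for a given XLA_PYTHON_CLIENT_MEM_FRACTION."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    return total_gb * fraction


print(preallocated_gb(80, 0.9))  # 72.0, the setting used in the commands above
print(preallocated_gb(80))  # 60.0 with JAX's default fraction of 0.75
```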
For full DROID RLDS training, add the rlds dependency group:
uv run --group rlds scripts/compute_norm_stats.py \
--config-name pi05_full_droid_finetune \
--max-frames 10000000
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py \
pi05_full_droid_finetune \
--exp-name=<run_name> --overwrite
Step 5: Serve and validate
uv run scripts/serve_policy.py policy:checkpoint \
--policy.config=<config_name> \
--policy.dir=checkpoints/<config_name>/<run_name>/<step>
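Loading a checkpoint takes time, and a client that connects before the server is listening will fail. This is a minimal sketch that polls the port first. It assumes the default localhost:8000 used by the client examples; wait_for_server is a hypothetical helper. The probe opens a bare TCP connection without a WebSocket handshake, so the server may log a failed handshake for it.

```python
import socket
import time


def wait_for_server(host: str = "localhost", port: int = 8000, timeout_s: float = 300.0) -> bool:
    """Return True once a TCP connection to host:port succeeds, False after timeout_s."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            # Opening and immediately closing a connection shows the port is accepting
            with socket.create_connection((host, port), timeout=2.0):
                return True
        except OSError:
            time.sleep(1.0)
    return False


if __name__ == "__main__":
    if not wait_for_server():
        raise SystemExit("policy server did not come up on localhost:8000")
```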
Verify with a test client:
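import numpy as np
from openpi_client import websocket_client_policy
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
# Placeholder observation: keys, image size, and state size must match the served config's
# input transforms (the LIBERO example uses "observation/image", "observation/state", ...)
obs = {
    "image": np.zeros((224, 224, 3), dtype=np.uint8),
    "state": np.zeros(8, dtype=np.float32),
    "prompt": "pick up the cup",
}
result = client.infer(obs)
print(f"Action shape: {result['actions'].shape}")  # (chunk_size, action_dim)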
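Printing the shape is a first check. The sanity check in step 5 can be made automatic with a small helper. This is a minimal sketch; validate_actions is hypothetical, and action_dim should be whatever your robot expects after the config's output transforms.

```python
import numpy as np


def validate_actions(result: dict, action_dim: int) -> np.ndarray:
    """Return result["actions"] as an array after checking shape and finiteness."""
    actions = np.asarray(result["actions"])
    if actions.ndim != 2 or actions.shape[1] != action_dim:
        raise ValueError(f"expected shape (chunk_size, {action_dim}), got {actions.shape}")
    if not np.all(np.isfinite(actions)):
        raise ValueError("actions contain NaN or inf values")
    return actions
```

Usage with the client above, assuming a 7-D action space: actions = validate_actions(result, action_dim=7).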
Workflow 2: PyTorch training and checkpoint conversion
Copy this checklist and track progress:
PyTorch Setup Progress:
- [ ] Step 1: Sync dependencies and verify the transformers version
- [ ] Step 2: Apply OpenPI transformers patches
- [ ] Step 3: Convert JAX checkpoint to PyTorch format
- [ ] Step 4: Launch PyTorch training or serve converted checkpoint
Step 1: Sync dependencies
uv sync
uv pip show transformers
Step 2: Apply required patches
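OpenPI's PyTorch models require patched files copied over the installed transformers package. The destination below assumes Python 3.11 in the repo's .venv; adjust the path if your interpreter version differs: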
cp -r ./src/openpi/models_pytorch/transformers_replace/* \
.venv/lib/python3.11/site-packages/transformers/
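To avoid hardcoding the python3.11 path, the same copy can be done by locating transformers from the active interpreter. This is a minimal sketch, assuming it is run from the repo root with uv run python so that the venv's transformers is found. transformers_dir and copy_overlay are hypothetical helpers. Unlike the shell glob src/*, iterdir also copies hidden files.

```python
import importlib.util
import shutil
from pathlib import Path


def transformers_dir() -> Path:
    """Directory of the transformers package visible to this interpreter."""
    spec = importlib.util.find_spec("transformers")
    if spec is None or spec.origin is None:
        raise RuntimeError("transformers is not installed in this environment")
    return Path(spec.origin).parent


def copy_overlay(src: Path, dst: Path) -> list[Path]:
    """Copy src's contents into dst, overwriting same-named files (like cp -r src/* dst/)."""
    copied = []
    for item in src.iterdir():
        target = dst / item.name
        if item.is_dir():
            # dirs_exist_ok merges into existing subpackages instead of failing
            shutil.copytree(item, target, dirs_exist_ok=True)
        else:
            shutil.copy2(item, target)
        copied.append(target)
    return copied


if __name__ == "__main__":
    copy_overlay(Path("src/openpi/models_pytorch/transformers_replace"), transformers_dir())
```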
Step 3: Convert JAX checkpoint
uv run examples/convert_jax_model_to_pytorch.py \
--checkpoint_dir <jax_checkpoint_dir> \
--config_name <config_name> \
--output_path <pytorch_checkpoint_dir>
Step 4: Train or serve
Single GPU training:
uv run scripts/train_pytorch.py <config_name> --exp_name <run_name>
Multi-GPU distributed training:
uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> \
scripts/train_pytorch.py <config_name> --exp_name <run_name>
Programmatic inference with converted checkpoint:
from openpi.training import config as _config
from openpi.policies import policy_config
config = _config.get_config("pi05_droid")
policy = policy_config.create_trained_policy(config, "<pytorch_checkpoint_di