scikit-survival -- Survival Analysis
Overview
scikit-survival is a Python library for time-to-event analysis built on scikit-learn. It handles right-censored data (observations where the event has not yet occurred) using Cox models, ensemble methods, survival SVMs, and non-parametric estimators. All models follow the scikit-learn fit/predict API and integrate with Pipelines, cross-validation, and GridSearchCV.
When to Use
- Modeling time-to-event outcomes with right-censored data (clinical trials, reliability)
- Fitting Cox proportional hazards models (standard or elastic net penalized)
- Building ensemble survival models (Random Survival Forest, Gradient Boosting)
- Training survival SVMs for margin-based learning on medium-sized datasets
- Evaluating survival predictions with censoring-aware metrics (C-index, Brier score, AUC)
- Estimating non-parametric survival curves (Kaplan-Meier, Nelson-Aalen)
- Analyzing competing risks with cumulative incidence functions
- High-dimensional survival data with automatic feature selection (CoxNet with an L1 or elastic-net penalty; a pure L2 penalty shrinks coefficients but does not select features)
- For fully parametric models (Weibull, log-normal AFT) or a broader set of statistical tests, use lifelines (scikit-survival itself provides the log-rank test via sksurv.compare.compare_survival)
- For deep learning survival models, use pycox or torchlife
Prerequisites
pip install scikit-survival scikit-learn pandas numpy matplotlib
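Python: >= 3.9 (recent releases may require a newer Python; check the release notes). Dependencies: scikit-learn, numpy, scipy, pandas, joblib, numexpr, ecos and osqp (ecos and osqp are the quadratic-program solvers used by some survival SVMs).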
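Data format: Survival outcomes are NumPy structured arrays with two fields: an event indicator followed by a time. Events are boolean (True = event occurred, False = censored). Times are positive floats. Field names are not fixed (Surv.from_arrays uses "event" and "time", but the built-in datasets use their own names), so access them via y.dtype.names rather than hard-coding "event".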
Quick Start
from sksurv.datasets import load_breast_cancer
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_ipcw
from sksurv.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer()
# The dataset includes categorical clinical columns; tree ensembles need
# numeric input, so one-hot encode them before splitting.
X = OneHotEncoder().fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

rsf = RandomSurvivalForest(n_estimators=100, random_state=42)
rsf.fit(X_train, y_train)
risk_scores = rsf.predict(X_test)  # higher = higher predicted risk

# IPCW C-index: the training outcomes are used to estimate the censoring
# distribution. If test times extend beyond the training follow-up, pass
# tau=<cutoff time> to restrict the evaluation window.
c_index = concordance_index_ipcw(y_train, y_test, risk_scores)[0]
print(f"C-index: {c_index:.3f}")  # e.g., 0.68

# Individual survival curves (step functions that can be evaluated at a time)
surv_fns = rsf.predict_survival_function(X_test[:2])
for fn in surv_fns:
    # 365 * 5 assumes the dataset's time axis is in days -- confirm for your data
    print(f"5-year survival: {fn(365 * 5):.3f}")
Core API
Module 1: Data Preparation
Create structured survival arrays and preprocess features.
import numpy as np
import pandas as pd
from sksurv.util import Surv
from sksurv.column import encode_categorical
from sksurv.preprocessing import OneHotEncoder
from sksurv.datasets import load_gbsg2, load_breast_cancer
from sklearn.preprocessing import StandardScaler

# Create survival outcome from arrays
event = np.array([True, False, True, True, False])
time = np.array([120.0, 365.0, 200.0, 90.0, 400.0])
y = Surv.from_arrays(event=event, time=time)
print(y.dtype)  # [('event', '?'), ('time', '<f8')]

# From DataFrame columns
# y = Surv.from_dataframe("event_col", "time_col", df)

# Load built-in datasets
# Available: load_gbsg2, load_breast_cancer, load_veterans_lung_cancer,
# load_whas500, load_aids, load_flchain
X, y = load_gbsg2()
# Built-in datasets use their own field names (not necessarily "event"/"time"),
# so look them up instead of hard-coding them. The first field is the event
# indicator, the second the observed time.
event_field, time_field = y.dtype.names
print(f"Shape: {X.shape}, Events: {y[event_field].sum()}, "
      f"Censoring rate: {1 - y[event_field].mean():.1%}")

# Encode categoricals: plain one-hot encoding of every categorical column
# (a column with M categories becomes M-1 indicator columns)
X_encoded = encode_categorical(X)

# Standardize (critical for Cox and SVM models).
# In a real workflow, fit the scaler on the training split only.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_encoded)

from sksurv.io import loadarff

# Load ARFF format (Weka format). loadarff returns a single pandas DataFrame,
# so split it into features and a structured survival array yourself.
# "event_col"/"time_col" are placeholders; the event column must be boolean or 0/1.
data = loadarff("survival_data.arff")
y_arff = Surv.from_dataframe("event_col", "time_col", data)
X_arff = data.drop(columns=["event_col", "time_col"])
Module 2: Cox Proportional Hazards
Semi-parametric model: h(t|x) = h_0(t) * exp(beta^T x), where the baseline hazard h_0(t) is left unspecified. The coefficients are interpretable as log hazard ratios.
from sksurv.linear_model import CoxPHSurvivalAnalysis, CoxnetSurvivalAnalysis, IPCRidge

# Standard Cox PH model
cox = CoxPHSurvivalAnalysis(alpha=0.0, ties="breslow")
cox.fit(X_train, y_train)
print(f"Coefficients: {cox.coef_}")  # log hazard ratios
# Hazard ratio interpretation: exp(coef) = HR for 1-unit increase
risk_scores = cox.predict(X_test)  # Higher = higher risk

# Survival function for individual patients
surv_funcs = cox.predict_survival_function(X_test[:3])
for fn in surv_funcs:
    # 365 * 5 assumes time is measured in days -- confirm for your data
    print(f"5-year survival: {fn(365 * 5):.3f}")

# Penalized Cox (elastic net) -- for high-dimensional data (p > n)
coxnet = CoxnetSurvivalAnalysis(
    l1_ratio=0.9,          # in (0, 1]: 1 = Lasso, values near 0 approach Ridge
    alpha_min_ratio=0.01,  # smallest alpha / largest alpha ratio
    n_alphas=100,          # steps in regularization path
)
coxnet.fit(X_train, y_train)
# coef_ holds the whole regularization path with shape (n_features, n_alphas):
# one column per alpha in coxnet.alphas_, from largest to smallest penalty.
# Pick a single column (ideally an alpha chosen by cross-validation) before
# counting selected features.
coef_at_smallest_alpha = coxnet.coef_[:, -1]
selected = np.flatnonzero(coef_at_smallest_alpha)
print(f"Selected {len(selected)} / {X_train.shape[1]} features")

# IPCRidge: accelerated failure time model -- ridge regression on log(time)
# with inverse-probability-of-censoring weights. Larger predictions mean longer
# predicted survival, so negate them when a risk score is needed (e.g. C-index).
# NOTE: whether predict() returns log time or time on the original scale has
# varied between releases -- check the docs for your installed version.
ipcridge = IPCRidge(alpha=1.0)
ipcridge.fit(X_train, y_train)
log_survival_time = ipcridge.predict(X_test)
Module 3: Ensemble Methods
Non-parametric tree-based models for complex non-linear relationships.
from sksurv.ensemble import (
    RandomSurvivalForest,
    GradientBoostingSurvivalAnalysis,
    ComponentwiseGradientBoostingSurvivalAnalysis,
    ExtraSurvivalTrees,
)

# Random Survival Forest -- robust, minimal tuning
rsf = RandomSurvivalForest(
    n_estimators=200, min_samples_split=10, min_samples_leaf=15,
    max_features="sqrt", random_state=42, n_jobs=-1,
)
rsf.fit(X_train, y_train)
risk = rsf.predict(X_test)  # higher = higher predicted risk

# Gradient Boosting with regression-tree base learners -- often strong, needs tuning
gbs = GradientBoostingSurvivalAnalysis(
    loss="coxph",  # "coxph", "squared", or "ipcwls" (IPCW least squares, AFT)
    n_estimators=300, learning_rate=0.05, max_depth=3,
    subsample=0.8, dropout_rate=0.1, random_state=42,
)
gbs.fit(X_train, y_train)

# ComponentwiseGB -- linear model with automatic feature selection
cgbs = ComponentwiseGradientBoostingSurvivalAnalysis(
    n_estimators=100, learning_rate=0.1,
)
cgbs.fit(X_train, y_train)
# coef_ has n_features + 1 entries: coef_[0] is the intercept (always zero for
# loss="coxph"), so skip it when counting selected features.
print(f"Non-zero coefficients: {np.count_nonzero(cgbs.coef_[1:])}")

# ExtraSurvivalTrees -- like RSF but with randomized split thresholds;
# typically faster to train
est = ExtraSurvivalTrees(n_estimators=100, random_state=42)
est.fit(X_train, y_train)

# Survival curves from any ensemble model
surv_funcs = rsf.predict_survival_function(X_test[:1])
chf_funcs = rsf.predict_cumulative_hazard_function(X_test[:1])
Module 4: Survival SVMs
Margin-based learning for survival ranking. Always standardize features.
from sklearn.preprocessing import StandardScaler
from sksurv.svm import FastSurvivalSVM, FastKernelSurvivalSVM, HingeLossSurvivalSVM

# SVMs are scale-sensitive: fit the scaler on the training split only and
# apply the same transformation to the test split.
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Linear SVM -- fast, for linear relationships
lsvm = FastSurvivalSVM(alpha=1.0, rank_ratio=1.0, max_iter=100, random_state=42)
lsvm.fit(X_train_scaled, y_train)
# NOTE: per the sksurv docs, higher predicted rank = longer survival; negate
# before using the output as a risk score (e.g. for a C-index).
risk = lsvm.predict(X_test_scaled)

# Kernel SVM -- for non-linear relationships (rbf, poly, sigmoid, ...)
ksvm = FastKernelSurvivalSVM(
    alpha=1.0, kernel="rbf",
    gamma=None,  # float or None; None uses scikit-learn's default 1 / n_features
    max_iter=50, random_state=42,
)
ksvm.fit(X_train_scaled, y_train)

# Hinge loss variant (quadratic program over sample pairs; no random_state)
hsvm = HingeLossSurvivalSVM(alpha=1.0)
hsvm.fit(X_train_scaled, y_train)
# NaiveSurvivalSVM is also available, but it builds a classification problem
# over all comparable sample pairs, so it scales poorly with n.

from sksurv.kernels import ClinicalKernelTransform

# Clinical kernel: a similarity measure for mixed clinical data that treats
# continuous, ordinal (ordered categorical) and nominal (unordered categorical)
# columns differently. Pass a pandas DataFrame; to have categorical columns
# recognized, prepare() on the un-encoded frame with categorical dtypes.
transform = ClinicalKernelTransform(fit_once=True)
transform.prepare(X_train)  # infer per-column variable types/ranges once
X_kern = transform.fit_transform(X_train)  # (n_train, n_train) kernel matrix
# Alternatively plug it into a kernel SVM directly, e.g.
# FastKernelSurvivalSVM(kernel=transform.pairwise_kernel, random_state=42)
Module 5: Non-Parametric Estimation
Estimate survival and cumulative hazard curves without parametric model assumptions.
from sksurv.nonparametric import kaplan_meier_estimator, nelson_aalen_estimator
import matplotlib.pyplot as plt
# Kaplan-Meier survival curve
time_km, surv_prob = kaplan_meier_estimator(y["event"], y["