torchdrug
Overview
TorchDrug is a comprehensive machine learning framework for drug discovery built on PyTorch. It provides graph-based molecular representations (atoms as nodes, bonds as edges), a library of graph neural network (GNN) architectures, benchmark datasets, and pretrained models for tasks including molecular property prediction, drug-target interaction, retrosynthesis, and generative molecular design. TorchDrug integrates with PyTorch Lightning and standard ML tooling, making it accessible to both computational chemists and ML practitioners.
When to Use
- Molecular property prediction: Training or fine-tuning GNN models to predict ADMET properties (solubility, toxicity, permeability) or bioactivity (IC50, Ki) from molecular graphs.
- Drug-target interaction (DTI) prediction: Building models that predict binding affinity between a compound (SMILES) and a protein (sequence or structure).
- Retrosynthesis prediction: Identifying plausible synthetic routes for a target molecule using template-based or template-free models.
- Pretraining on large molecular datasets: Leveraging GNN representations pretrained on ChEMBL or ZINC for transfer learning to small datasets.
- Molecular generation: Training graph-based generative models (GCPN, GraphAF) to design novel molecules with desired properties.
- Benchmarking GNN architectures: Comparing GraphConv, MPNN, GAT, AttentiveFP on standard MoleculeNet tasks.
- Not the right tool for fast fingerprint-based property prediction without deep learning: use RDKit + scikit-learn instead.
- Not the right tool for protein structure tasks (folding, docking): use ESMFold or DiffDock rather than TorchDrug.
Prerequisites
- Python packages: `torchdrug`, `torch`, `torch-geometric`, `rdkit`
- Environment: Python 3.8+, CUDA-compatible GPU recommended for training
- Data requirements: SMILES strings or molecular SDF files; protein sequences for DTI tasks
pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118
pip install torch-geometric
pip install torchdrug
pip install rdkit
Quick Start
import torch
from torchdrug import data, datasets, models, tasks, core
# Load a benchmark dataset and train a GNN for property prediction
dataset = datasets.BBBP("~/data/bbbp", node_feature="default", edge_feature="default")
print(f"Dataset: {len(dataset)} molecules, task: BBBP (blood-brain barrier penetration)")
# Define model: GIN encoder
model = models.GIN(
input_dim=dataset.node_feature_dim,
hidden_dims=[256, 256],
short_cut=True,
batch_norm=True,
concat_hidden=True,
)
# Define training task
task = tasks.PropertyPrediction(
model, task=dataset.tasks,
criterion="bce", metric=("auprc", "auroc"),
)
# Train with the Solver
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0])
solver.train(num_epoch=50)
print("Training complete")
Core API
Module 1: Molecular Graph Representation
TorchDrug represents molecules as typed graphs. data.Molecule is the core data structure.
from torchdrug import data
from rdkit import Chem
# Create a molecule from SMILES
smiles = "CC(=O)Oc1ccccc1C(=O)O" # aspirin
mol = data.Molecule.from_smiles(smiles, node_feature="default", edge_feature="default")
print(f"Atoms: {mol.num_node}")
print(f"Bonds: {mol.num_edge}")
print(f"Node feature dim: {mol.node_feature.shape}") # (N_atoms, feature_dim)
print(f"Edge feature dim: {mol.edge_feature.shape}") # (N_bonds*2, feature_dim)
# Convert a MoleculeNet / custom SMILES list into a list of Molecule objects
from torchdrug import data as td_data
import pandas as pd
df = pd.read_csv("compounds.csv") # columns: smiles, label
molecules = [td_data.Molecule.from_smiles(s) for s in df["smiles"] if s]
print(f"Loaded {len(molecules)} valid molecules")
# Check feature dimensions
print(f"Default atom feature dim: {molecules[0].node_feature.shape[1]}")
Module 2: GNN Architectures
TorchDrug provides GIN, RGCN, GraphSAGE, GAT, MPNN, AttentiveFP, and more.
from torchdrug import models, datasets
dataset = datasets.ESOL("~/data/esol", node_feature="default", edge_feature="default")
feature_dim = dataset.node_feature_dim
# Graph Isomorphism Network (GIN) — good default for property prediction
gin = models.GIN(
input_dim=feature_dim,
hidden_dims=[256, 256, 256],
short_cut=True,
batch_norm=True,
concat_hidden=True, # concatenate layer representations
)
print(f"GIN output_dim: {gin.output_dim}")
from torchdrug import models
# Message Passing Neural Network (MPNN) — captures edge features
mpnn = models.MPNN(
input_dim=feature_dim,
hidden_dim=256,
edge_input_dim=16, # edge feature dimension
num_layer=4,
num_gru_layer=1,
)
# Graph Attention Network (GAT) — attention-weighted neighbors
gat = models.GAT(
input_dim=feature_dim,
hidden_dims=[256, 256],
edge_input_dim=16,
num_head=8,
batch_norm=True,
)
print(f"MPNN output_dim: {mpnn.output_dim}, GAT output_dim: {gat.output_dim}")
Module 3: Molecular Property Prediction
Wrap a GNN encoder with a prediction head for classification or regression.
import torch
from torchdrug import datasets, models, tasks, core
# Regression example: ESOL aqueous solubility
dataset = datasets.ESOL("~/data/esol", node_feature="default", edge_feature="default")
train, val, test = dataset.split()
print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")
model = models.GIN(
input_dim=dataset.node_feature_dim,
hidden_dims=[300, 300],
short_cut=True,
batch_norm=True,
concat_hidden=True,
)
task = tasks.PropertyPrediction(
model,
task=dataset.tasks, # list of property names
criterion="mse", # "mse" for regression, "bce" for classification
metric=("mae", "rmse"),
num_mlp_layer=2,
)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3, weight_decay=1e-5)
solver = core.Engine(task, train, val, test, optimizer,
batch_size=32, log_interval=50)
solver.train(num_epoch=100)
# Evaluate on test set
metrics = solver.evaluate("test")
print(f"Test RMSE: {metrics['rmse']:.4f}")
print(f"Test MAE: {metrics['mae']:.4f}")
Module 4: Drug-Target Interaction (DTI) Prediction
Predict binding affinity between molecules and protein sequences.
from torchdrug import datasets, models, tasks, core
import torch
# Load a DTI dataset (e.g., Davis kinase binding affinities)
dataset = datasets.Davis("~/data/davis",
mol_node_feature="default",
mol_edge_feature="default")
train, val, test = dataset.split()
# Molecule encoder
mol_model = models.GIN(
input_dim=dataset.mol_node_feature_dim,
hidden_dims=[256, 256],
short_cut=True,
batch_norm=True,
concat_hidden=True,
)
# Protein encoder (CNN on sequence)
prot_model = models.ProteinCNN(
input_dim=21, # amino acid vocabulary size
hidden_dims=[128, 128, 128],
kernel_size=3,
)
task = tasks.InteractionPrediction(
mol_model, prot_model,
task=dataset.tasks,
criterion="mse",
metric=("rmse", "pearsonr"),
)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train, val, test, optimizer,
batch_size=64, log_interval=100)
solver.train(num_epoch=50)
metrics = solver.evaluate("test")
print(f"DTI Test RMSE: {metrics['rmse']:.4f}")
print(f"DTI Pearson r: {metrics['pearsonr']:.4f}")
Module 5: Retrosynthesis Prediction
Predict one-step retrosynthetic disconnections to find plausible building blocks.
from torchdrug import datasets, models, tasks, core
import torch
# USPTO-50k retrosynthesis benchmark
dataset = datasets.USPTO50k("~/data/uspto50k",
as_synthon=False,
atom_feature=