Distributed / multi-GPU training#
To train across several GPUs (or nodes), every process (rank) must see a disjoint slice of the data; otherwise, ranks waste work on duplicate samples.
DistributedSampler does exactly that: it wraps any base sampler and restricts each rank to its own contiguous shard of the observations.
It reads the rank and world size from torch.distributed (dist_info="torch"), from JAX (dist_info="jax"), or from a callable you provide.
So the same collection feeds a PyTorch DDP or a JAX pmap job unchanged.
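To make "contiguous shard" concrete, here is a minimal sketch in plain Python. It is not annbatch's implementation: the helper name `shard_bounds` is hypothetical, and the choice to give every rank `n_obs // world_size` rows (dropping any remainder) is an assumption made for illustration.

```python
def shard_bounds(n_obs: int, rank: int, world_size: int) -> tuple[int, int]:
    """Return the half-open [start, stop) range of observations for one rank.

    Hypothetical helper: assumes equal-sized contiguous shards and drops
    any remainder rows that do not divide evenly across ranks.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    per_rank = n_obs // world_size  # rows owned by each rank
    start = rank * per_rank
    return start, start + per_rank


# Two ranks over the 20,000 observations used in this tutorial.
for rank in range(2):
    print(rank, shard_bounds(20_000, rank, 2))
# 0 (0, 10000)
# 1 (10000, 20000)
```

Because each range starts where the previous one stops, no observation can land in two shards.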
Configure zarrs#
import warnings
import zarr
zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
for msg in ("Consolidated metadata is currently not part in the Zarr format 3 specification.*",):
    warnings.filterwarnings("ignore", message=msg)
A small collection to shard#
Any DatasetCollection works; here we make a small synthetic one (see Single-cell RNA-seq or Genetics (VCF / whole-genome sequencing) for real data).
import anndata as ad
import pandas as pd
import scipy.sparse as sp
from annbatch import DatasetCollection
# AnnData 0.13.0 defaults to the following settings
ad.settings.zarr_write_format = 3
ad.settings.auto_shard_zarr_v3 = True
n_cells, n_genes = 20_000, 256
adata = ad.AnnData(
    X=sp.random(n_cells, n_genes, density=0.1, format="csr", dtype="float32", random_state=0),
    var=pd.DataFrame(index=[f"gene_{i}" for i in range(n_genes)]),
)
adata.write_zarr("synthetic.zarr")
DatasetCollection(zarr.open("ddp_collection.zarr", mode="w")).add_adatas(
    adata_paths=["synthetic.zarr"], shuffle=True, dataset_size=10_000
)
How the shards line up#
DistributedSampler wraps a base sampler (e.g. RandomSampler).
With a callable dist_info, we can inspect how a 2-rank job would split the data without launching any processes.
Each rank yields the same number of complete batches over its own slice.
import anndata as ad
from annbatch import Loader
from annbatch.samplers import DistributedSampler, RandomSampler
def _load_x(group: zarr.Group) -> ad.AnnData:
    return ad.AnnData(X=ad.io.sparse_dataset(group["X"]))
collection = DatasetCollection(zarr.open("ddp_collection.zarr", mode="r"))
for rank in range(2):
    sampler = DistributedSampler(
        RandomSampler(chunk_size=64, preload_nchunks=32, batch_size=256),
        dist_info=lambda rank=rank: (rank, 2),
    )
    loader = Loader(batch_sampler=sampler, preload_to_gpu=False).use_collection(collection, load_adata=_load_x)
    print(f"rank {rank} of 2 → {len(loader)} batches/epoch")
rank 0 of 2 → 39 batches/epoch
rank 1 of 2 → 39 batches/epoch
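The figure of 39 batches per rank can be reproduced with simple arithmetic. The sketch below assumes each rank owns `n_obs // world_size` rows and keeps only complete batches. The helper `batches_per_rank` is hypothetical, and the sampler's actual handling of chunk boundaries and remainders may differ.

```python
def batches_per_rank(n_obs: int, world_size: int, batch_size: int) -> int:
    """Complete batches each rank yields under the simplifying assumptions above."""
    shard_rows = n_obs // world_size  # rows in one rank's shard
    return shard_rows // batch_size   # partial trailing batch is not counted


n_batches = batches_per_rank(n_obs=20_000, world_size=2, batch_size=256)
print(n_batches, n_batches * 256)
# 39 9984
```

Under these assumptions each rank processes 9,984 of its 10,000 rows per epoch, and every rank runs the same number of steps.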
Run it for real on multiple GPUs#
In a real job each rank is a separate process launched with torchrun.
The training script below initialises torch.distributed, builds the DistributedSampler with dist_info="torch", and streams its shard to its own GPU.
%%writefile ddp_train.py
import anndata as ad
import torch
import torch.distributed as dist
import zarr
zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
from annbatch import DatasetCollection, Loader
from annbatch.samplers import DistributedSampler, RandomSampler
dist.init_process_group("gloo") # production GPU training typically uses "nccl"
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
collection = DatasetCollection(zarr.open("ddp_collection.zarr", mode="r"))
sampler = DistributedSampler(
    RandomSampler(chunk_size=64, preload_nchunks=32, batch_size=256),
    dist_info="torch",
)
loader = Loader(batch_sampler=sampler, preload_to_gpu=False).use_collection(
    collection, load_adata=lambda g: ad.AnnData(X=ad.io.sparse_dataset(g["X"]))
)
rows = 0
for batch in loader:
    x = batch["X"].cuda().to_dense()  # densify on this rank's GPU
    rows += x.shape[0]
print(f"[rank {rank}/{world}] cuda:{torch.cuda.current_device()} streamed {rows} rows in {len(loader)} batches")
dist.barrier()
dist.destroy_process_group()
!torchrun --nproc_per_node=2 --master_port=29571 ddp_train.py 2>&1 | grep "rank"
Each rank streamed a disjoint half of the collection onto its own GPU. The two shards never overlap, so within an epoch no observation is visited by more than one rank.
For production GPU training, initialise the process group with the nccl backend instead of gloo.
For JAX, call jax.distributed.initialize() and pass dist_info="jax".
enforce_equal_batches=True (the default) trims each rank to the same number of complete batches so collectives stay in lock-step.
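The training script passes the global rank to torch.cuda.set_device, which works on a single node. On multi-node jobs the GPU index is normally the local rank instead. torchrun exports RANK, WORLD_SIZE and LOCAL_RANK as environment variables, so a callable dist_info could read them directly. The following is a sketch only: both helper names are hypothetical, and the zero-argument callable returning (rank, world_size) is assumed from the lambda used in the inspection example earlier.

```python
import os


def torchrun_dist_info() -> tuple[int, int]:
    """Hypothetical dist_info callable: (global rank, world size) from torchrun's env vars.

    Falls back to a single-process setup (rank 0 of 1) when the variables are unset.
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank, world_size


def local_device_index() -> int:
    """Hypothetical helper: GPU index on this node, taken from torchrun's LOCAL_RANK."""
    return int(os.environ.get("LOCAL_RANK", "0"))


print(torchrun_dist_info(), local_device_index())
```

In the script, this would mean calling torch.cuda.set_device(local_device_index()) and passing dist_info=torchrun_dist_info, assuming the sampler accepts any callable of that shape.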
See Advanced: Implementing a Custom Sampler for writing your own sampling strategies.