Distributed / multi-GPU training

To train across several GPUs (or nodes), every process (rank) must see a disjoint slice of the data; otherwise ranks waste work on duplicate samples. DistributedSampler does exactly that: it wraps any base sampler and restricts each rank to its own contiguous shard of the observations.
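To make "contiguous shard" concrete, here is a minimal sketch of how a range of observations can be split into one contiguous block per rank. The helper name `shard_bounds` is hypothetical, and the rule for distributing a remainder is an assumption; this illustrates the idea and is not annbatch's actual implementation.

```python
def shard_bounds(n_obs: int, rank: int, world_size: int) -> tuple[int, int]:
    """Return the half-open [start, stop) range of observations owned by `rank`."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    base, extra = divmod(n_obs, world_size)
    # Assumption: the first `extra` ranks take one additional observation each.
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return start, stop


# Two ranks over the 20,000 observations used in this tutorial.
bounds = [shard_bounds(20_000, r, 2) for r in range(2)]
print(bounds)  # [(0, 10000), (10000, 20000)]
```

Because the ranges are half-open and adjacent, no observation belongs to two ranks.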

It reads the rank and world size from torch.distributed (dist_info="torch"), from JAX (dist_info="jax"), or from a callable you provide. As a result, the same collection can feed a PyTorch DDP job or a JAX pmap job unchanged.
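As an illustration of the callable form, the sketch below reads the rank and world size from the RANK and WORLD_SIZE environment variables, which torchrun exports to every worker. The function name `env_dist_info` is hypothetical, and the `(rank, world_size)` return shape is assumed from the lambda used later on this page.

```python
import os


def env_dist_info() -> tuple[int, int]:
    """Hypothetical dist_info callable: read (rank, world_size) from the environment."""
    # torchrun sets RANK and WORLD_SIZE; fall back to a single-process job otherwise.
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    return rank, world_size
```

Under that assumption it could be passed as `dist_info=env_dist_info` when constructing the DistributedSampler, which avoids importing torch or JAX just to discover the rank.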

Configure zarrs

import warnings

import zarr

zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
for msg in ("Consolidated metadata is currently not part in the Zarr format 3 specification.*",):
    warnings.filterwarnings("ignore", message=msg)

A small collection to shard

Any DatasetCollection works; here we make a small synthetic one (see Single-cell RNA-seq or Genetics (VCF / whole-genome sequencing) for real data).

import anndata as ad
import pandas as pd
import scipy.sparse as sp

from annbatch import DatasetCollection

# AnnData 0.13.0 defaults to the following settings
ad.settings.zarr_write_format = 3
ad.settings.auto_shard_zarr_v3 = True

n_cells, n_genes = 20_000, 256
adata = ad.AnnData(
    X=sp.random(n_cells, n_genes, density=0.1, format="csr", dtype="float32", random_state=0),
    var=pd.DataFrame(index=[f"gene_{i}" for i in range(n_genes)]),
)
adata.write_zarr("synthetic.zarr")

DatasetCollection(zarr.open("ddp_collection.zarr", mode="w")).add_adatas(
    adata_paths=["synthetic.zarr"], shuffle=True, dataset_size=10_000
)

<annbatch.io.DatasetCollection at 0x755f67e0e510>

How the shards line up

DistributedSampler wraps a base sampler (e.g. RandomSampler). With a callable dist_info, we can inspect how a 2-rank job would split the data without launching any processes. Each rank yields the same number of complete batches over its own slice.

import anndata as ad

from annbatch import Loader
from annbatch.samplers import DistributedSampler, RandomSampler


def _load_x(group: zarr.Group) -> ad.AnnData:
    return ad.AnnData(X=ad.io.sparse_dataset(group["X"]))


collection = DatasetCollection(zarr.open("ddp_collection.zarr", mode="r"))
for rank in range(2):
    sampler = DistributedSampler(
        RandomSampler(chunk_size=64, preload_nchunks=32, batch_size=256),
        dist_info=lambda rank=rank: (rank, 2),
    )
    loader = Loader(batch_sampler=sampler, preload_to_gpu=False).use_collection(collection, load_adata=_load_x)
    print(f"rank {rank} of 2 → {len(loader)} batches/epoch")
rank 0 of 2 → 39 batches/epoch
rank 1 of 2 → 39 batches/epoch
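The figure of 39 batches per rank can be checked by hand. The sketch below assumes the observations are split evenly across ranks and that a trailing incomplete batch is dropped; both are assumptions about the sampler's behaviour, inferred from the output above.

```python
n_obs, world_size, batch_size = 20_000, 2, 256

rows_per_rank = n_obs // world_size             # 10_000 rows in each shard
batches_per_rank = rows_per_rank // batch_size  # only complete batches count
rows_streamed = batches_per_rank * batch_size   # rows a rank yields per epoch

print(batches_per_rank, rows_streamed)  # 39 9984
```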

Run it for real on multiple GPUs

In a real job each rank is a separate process launched with torchrun. The training script below initialises torch.distributed, builds the DistributedSampler with dist_info="torch", and streams its shard to its own GPU.

%%writefile ddp_train.py
import anndata as ad
import torch
import torch.distributed as dist
import zarr

zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})

from annbatch import DatasetCollection, Loader
from annbatch.samplers import DistributedSampler, RandomSampler

dist.init_process_group("gloo")  # production GPU training typically uses "nccl"
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)  # single node: global rank == local GPU index; multi-node jobs should use LOCAL_RANK

collection = DatasetCollection(zarr.open("ddp_collection.zarr", mode="r"))
sampler = DistributedSampler(
    RandomSampler(chunk_size=64, preload_nchunks=32, batch_size=256),
    dist_info="torch",
)
loader = Loader(batch_sampler=sampler, preload_to_gpu=False).use_collection(
    collection, load_adata=lambda g: ad.AnnData(X=ad.io.sparse_dataset(g["X"]))
)

rows = 0
for batch in loader:
    x = batch["X"].cuda().to_dense()  # densify on this rank's GPU
    rows += x.shape[0]
print(f"[rank {rank}/{world}] cuda:{torch.cuda.current_device()} streamed {rows} rows in {len(loader)} batches")
dist.barrier()
dist.destroy_process_group()

Writing ddp_train.py
!torchrun --nproc_per_node=2 --master_port=29571 ddp_train.py 2>&1 | grep "rank"

[rank 0/2] cuda:0 streamed 9984 rows in 39 batches
[rank 1/2] cuda:1 streamed 9984 rows in 39 batches

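Each rank streamed a disjoint half of the collection onto its own GPU. The two shards never overlap, so no observation is visited more than once per epoch across all ranks. Because only complete batches of 256 are yielded, the few rows that do not fill a batch (here 20,000 − 2 × 9,984 = 32 in total) are skipped.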

  • For production GPU training, initialise the process group with the nccl backend instead of gloo.

  • For JAX, call jax.distributed.initialize() and pass dist_info="jax".

  • enforce_equal_batches=True (the default) trims each rank to the same number of complete batches so collectives stay in lock-step.
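To see why equal batch counts matter: a collective such as an all-reduce blocks until every rank joins, so a rank holding one extra batch would wait forever for peers that have already finished. The sketch below shows the trimming rule in isolation. The helper name `equal_batches` and the uneven shard sizes are hypothetical, and this is not annbatch's actual implementation.

```python
def equal_batches(shard_sizes: list[int], batch_size: int) -> int:
    """Number of complete batches that every rank can yield in lock-step."""
    # The rank with the fewest complete batches sets the limit for all ranks.
    return min(size // batch_size for size in shard_sizes)


# Hypothetical uneven shards: rank 0 could fill 5 batches, rank 1 only 4.
print(equal_batches([520, 480], batch_size=100))  # 4
```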

See Advanced: Implementing a Custom Sampler for writing your own sampling strategies.