Single-cell RNA-seq

Single-cell RNA-seq data is a cells × genes matrix that is mostly zeros, so it is commonly stored as a sparse (CSR) matrix in the X field of an AnnData object. Loader is designed to yield data from such matrices efficiently.
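To make the CSR layout concrete, here is a minimal pure-Python sketch. The helper names `dense_to_csr` and `csr_row` are hypothetical and exist only for illustration; real pipelines would rely on an existing sparse-matrix library rather than hand-rolled lists.

```python
# Toy illustration of the CSR layout (not annbatch or scipy code).
def dense_to_csr(rows):
    """Convert a list of equal-length rows into CSR arrays (data, indices, indptr)."""
    data, indices, indptr = [], [], [0]
    for row in rows:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)    # the nonzero count
                indices.append(col)   # its column (gene) index
        indptr.append(len(data))      # offset where the next row (cell) starts
    return data, indices, indptr


def csr_row(data, indices, indptr, i, n_cols):
    """Rebuild dense row i from the CSR arrays."""
    row = [0] * n_cols
    for k in range(indptr[i], indptr[i + 1]):
        row[indices[k]] = data[k]
    return row


# 3 cells x 5 genes, mostly zeros
counts = [
    [0, 0, 3, 0, 0],
    [0, 0, 0, 0, 0],
    [1, 0, 0, 0, 7],
]
data, indices, indptr = dense_to_csr(counts)
print(data, indices, indptr)  # [3, 1, 7] [2, 0, 4] [0, 1, 1, 3]
```

Only the three nonzero values are stored, and a block of consecutive cells maps to one contiguous slice of `data` and `indices`, which is why reading contiguous rows from a CSR matrix is cheap.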

This guide runs the full pipeline on two real human datasets from CELLxGENE (~170k cells together).

If you are completely new to annbatch, skim the Quickstart first.

Configure zarrs

The zarrs-python codec pipeline is what makes local reads fast; set it once before touching any store.

import warnings

import zarr

zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
for msg in (
    "Consolidated metadata is currently not part in the Zarr format 3 specification.*",
    "Observation names are not unique.*",
):
    warnings.filterwarnings("ignore", message=msg)

Get the data

Download two human datasets from CELLxGENE. You can substitute your own .h5ad files; every step below works the same way, provided obs contains the cell_type column that the code selects.

!test -f human_dataset_1.h5ad || wget -q -O human_dataset_1.h5ad https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad
!test -f human_dataset_2.h5ad || wget -q -O human_dataset_2.h5ad https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad

Both datasets store raw counts as a sparse CSR matrix (in .raw.X for CELLxGENE), annotated with a categorical cell_type in obs.

import anndata as ad

adata_paths = ["human_dataset_1.h5ad", "human_dataset_2.h5ad"]
for path in adata_paths:
    a = ad.experimental.read_lazy(path)
    print(f"{path}: {a.shape[0]:,} cells × {a.shape[1]:,} genes (dtype {a.X.dtype})")

human_dataset_1.h5ad: 99,457 cells × 35,475 genes (dtype float32)
human_dataset_2.h5ad: 72,335 cells × 36,406 genes (dtype float32)

Convert into a shuffled collection

add_adatas() reads every file lazily, outer-joins their gene spaces, shuffles cells across all of them, and writes a sharded zarr collection split into datasets of roughly dataset_size cells.

We pass a load_adata function that keeps only what we need: the raw counts and the cell_type label. This avoids loading unused obs columns and ensures every stored matrix has the same dtype. Uniform dtypes matter for performance; see Preshuffling Performance Considerations.

from annbatch import DatasetCollection


def _to_raw_counts(path: str) -> ad.AnnData:
    a = ad.experimental.read_lazy(path)
    x, var = (a.raw.X, a.raw.var) if a.raw is not None else (a.X, a.var)
    return ad.AnnData(X=x, obs=a.obs[["cell_type"]].to_memory(), var=var.to_memory())


collection = DatasetCollection(zarr.open("scrnaseq_collection.zarr", mode="w"))
collection.add_adatas(
    adata_paths=adata_paths,
    load_adata=_to_raw_counts,
    shuffle=True,
    dataset_size=100_000,  # cells per on-disk dataset; pick what fits comfortably in RAM
)
print("datasets in collection:", len(list(collection)))

datasets in collection: 2
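The outer join of gene spaces mentioned above can be pictured with a small pure-Python sketch. This is not annbatch's implementation: the helpers `outer_join_genes` and `align_row` and the gene names are hypothetical, and the sketch assumes that genes missing from a dataset are filled with zeros.

```python
def outer_join_genes(gene_lists):
    """Union of gene names across datasets, keeping first-seen order."""
    joined = {}
    for genes in gene_lists:
        for gene in genes:
            joined.setdefault(gene, None)
    return list(joined)


def align_row(row, genes, target_genes):
    """Re-express one cell's counts in the target gene order.

    Genes absent from the source dataset are zero-filled (an assumption of this sketch).
    """
    lookup = dict(zip(genes, row))
    return [lookup.get(gene, 0) for gene in target_genes]


genes_a = ["GENE1", "GENE2", "GENE3"]  # hypothetical gene space of dataset A
genes_b = ["GENE2", "GENE4"]           # hypothetical gene space of dataset B

joined = outer_join_genes([genes_a, genes_b])
aligned = align_row([5, 9], genes_b, joined)  # one cell from dataset B
print(joined)   # ['GENE1', 'GENE2', 'GENE3', 'GENE4']
print(aligned)  # [0, 5, 0, 9]
```

After alignment every cell has the same number of columns, so cells from different files can be shuffled into the same on-disk dataset.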

Stream shuffled mini-batches

The Loader fetches preload_nchunks contiguous chunks of chunk_size cells, shuffles them together, and yields batches of batch_size. As a rule of thumb, keep preload_nchunks around 32 so that chunk_size * preload_nchunks stays well above batch_size: the larger that shuffle pool, the better the mixing. With the settings below, the pool holds 512 × 32 = 16,384 cells, four times the batch size of 4,096. Each batch carries a sparse X; densify it on the GPU for speed.

from annbatch import Loader


def _load_adata(group: zarr.Group) -> ad.AnnData:
    return ad.AnnData(
        X=ad.io.sparse_dataset(group["X"]),
        obs=ad.experimental.read_lazy(group).obs[["cell_type"]].to_memory(),
    )


loader = Loader(
    batch_size=4096,
    chunk_size=512,
    preload_nchunks=32,
    preload_to_gpu=True,
).use_collection(collection, load_adata=_load_adata)

batch = next(iter(loader))
x = batch["X"].cuda().to_dense()
print(f"batch X: {tuple(x.shape)} {x.dtype} on {x.device}")
print(f"distinct cell types in this batch: {batch['obs']['cell_type'].nunique()}")

batch X: (4096, 36406) torch.float32 on cuda:0
distinct cell types in this batch: 19
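The chunk-then-shuffle strategy described in this section can be sketched on plain integer indices. This is a simplified illustration rather than the Loader's actual code: `chunked_shuffle_batches` is a hypothetical name, and the sketch assumes that leftover cells are not carried from one pool into the next, so the last batch of each pool may be short.

```python
import random


def chunked_shuffle_batches(n_obs, chunk_size, preload_nchunks, batch_size, seed=0):
    """Yield lists of cell indices from contiguous chunks shuffled in memory."""
    rng = random.Random(seed)
    # Start offsets of contiguous chunks; each chunk corresponds to one cheap sequential read.
    starts = list(range(0, n_obs, chunk_size))
    rng.shuffle(starts)  # visit chunks in random order
    for i in range(0, len(starts), preload_nchunks):
        pool = []
        for start in starts[i : i + preload_nchunks]:
            pool.extend(range(start, min(start + chunk_size, n_obs)))
        rng.shuffle(pool)  # mix cells across the preloaded chunks
        for j in range(0, len(pool), batch_size):
            yield pool[j : j + batch_size]


batches = list(
    chunked_shuffle_batches(n_obs=171_792, chunk_size=512, preload_nchunks=32, batch_size=4096)
)
print(len(batches), "batches; first batch size:", len(batches[0]))
```

Every cell index appears exactly once per pass, and each batch mixes cells from up to 32 different regions of the collection. A larger pool (chunk_size * preload_nchunks) draws from more regions per batch at the cost of more memory.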

Throughput

The following cell times a full pass over the collection, densifying every batch on the GPU.

import time

n_cells = loader.n_obs
start = time.perf_counter()
for batch in loader:
    _ = batch["X"].cuda().to_dense()
elapsed = time.perf_counter() - start
print(f"streamed {n_cells:,} cells in {elapsed:.1f}s → {n_cells / elapsed:,.0f} cells/s")

streamed 171,792 cells in 1.1s → 151,953 cells/s

Extend a collection with new data

When new samples arrive, you do not need to re-shuffle everything. Calling add_adatas() again shuffles the new cells into the existing datasets and subsets the new data to the collection’s gene space. Here we re-add one file to show the call; in practice you would pass genuinely new files.

collection.add_adatas(adata_paths=["human_dataset_1.h5ad"], load_adata=_to_raw_counts)
print("cells after extending:", collection.obs(columns=["cell_type"]).shape[0])

cells after extending: 271249

Scaling up

To scale up, pass more files to add_adatas and raise dataset_size to the largest value that still fits comfortably in RAM. For class-balanced batches (e.g. to oversample rare cell types), write the collection with groupby="cell_type" and use a class-aware sampler; see Advanced: Implementing a Custom Sampler. For perturbation data, ClassSampler lets you sample equally across conditions. To shard a collection across multiple GPUs, see Distributed / multi-GPU training.
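As a rough intuition for what class-balanced sampling does, the sketch below draws cells with weights inversely proportional to class size. It is a generic illustration in plain Python, not annbatch's ClassSampler or its groupby mechanism; the function `balanced_sample` and the toy labels are hypothetical.

```python
import random
from collections import Counter


def balanced_sample(labels, n_samples, seed=0):
    """Draw indices with replacement so that every class is equally likely."""
    counts = Counter(labels)
    # Each cell gets weight 1 / (size of its class), so each class has total weight 1.
    weights = [1.0 / counts[label] for label in labels]
    rng = random.Random(seed)
    return rng.choices(range(len(labels)), weights=weights, k=n_samples)


# Toy data: one abundant and one rare cell type (95% vs 5% of cells).
labels = ["T cell"] * 950 + ["mast cell"] * 50
picked = balanced_sample(labels, n_samples=10_000)
share_rare = sum(labels[i] == "mast cell" for i in picked) / len(picked)
print(f"rare class share in the sample: {share_rare:.2f}")
```

The rare class makes up 5% of the cells but roughly half of the draws, which is the effect oversampling aims for. Sampling with replacement means rare cells are repeated, so this trades diversity for balance.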