Single-cell RNA-seq
Single-cell RNA-seq data is a cells × genes matrix that is mostly zeros and therefore commonly stored as a sparse (CSR) matrix in X.
The annbatch Loader is designed to yield batches from such matrices efficiently.
This guide runs the full pipeline on two real human datasets from CELLxGENE (~170k cells together).
If you are completely new to annbatch, skim the Quickstart first.
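To make the storage claim concrete, here is a minimal pure-Python sketch of the CSR layout, which keeps only the non-zero counts plus two index arrays. The toy matrix and the helper names `to_csr` and `csr_row` are illustrative assumptions and not part of annbatch or anndata; real pipelines would rely on a sparse-matrix library rather than lists.

```python
# Toy cells × genes count matrix: 4 cells, 6 genes, mostly zeros.
dense = [
    [0, 0, 3, 0, 0, 0],
    [1, 0, 0, 0, 0, 2],
    [0, 0, 0, 0, 0, 0],
    [0, 5, 0, 0, 4, 0],
]


def to_csr(rows):
    """Convert a dense list-of-lists into CSR arrays (data, indices, indptr)."""
    data, indices, indptr = [], [], [0]
    for row in rows:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)    # the non-zero count itself
                indices.append(col)   # gene (column) index of that count
        indptr.append(len(data))      # offset where the next cell's entries start
    return data, indices, indptr


def csr_row(data, indices, indptr, i, n_cols):
    """Rebuild dense row i (one cell) from the CSR arrays."""
    row = [0] * n_cols
    for k in range(indptr[i], indptr[i + 1]):
        row[indices[k]] = data[k]
    return row


data, indices, indptr = to_csr(dense)
n_stored = len(data) + len(indices) + len(indptr)
n_dense = len(dense) * len(dense[0])
print(f"stored values: {n_stored} (CSR) vs {n_dense} (dense)")
```

Because `indptr` marks where each cell's entries begin, a contiguous range of cells can be read as one slice of `data` and `indices`. The saving over dense storage grows with the fraction of zeros.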
Configure zarrs
The zarrs-python codec pipeline is what makes local reads fast; set it once before touching any store.
import warnings
import zarr
zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})
# Suppress two known warnings so the output below stays readable.
for msg in (
    "Consolidated metadata is currently not part in the Zarr format 3 specification.*",
    "Observation names are not unique.*",
):
    warnings.filterwarnings("ignore", message=msg)
Get the data
Download two human datasets from CELLxGENE.
Replace these with your own .h5ad files; every step below stays the same.
!test -f human_dataset_1.h5ad || wget -q -O human_dataset_1.h5ad https://datasets.cellxgene.cziscience.com/866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad
!test -f human_dataset_2.h5ad || wget -q -O human_dataset_2.h5ad https://datasets.cellxgene.cziscience.com/f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad
Both datasets store raw counts as a sparse CSR matrix (CELLxGENE files keep them in .raw.X) and annotate each cell with a categorical cell_type column in obs.
import anndata as ad
adata_paths = ["human_dataset_1.h5ad", "human_dataset_2.h5ad"]
for path in adata_paths:
    # read_lazy opens the file without loading X into memory
    a = ad.experimental.read_lazy(path)
    print(f"{path}: {a.shape[0]:,} cells × {a.shape[1]:,} genes (dtype {a.X.dtype})")
Convert into a shuffled collection
add_adatas() reads every file lazily, outer-joins their gene spaces, shuffles cells across all of them, and writes a sharded zarr collection split into datasets of roughly dataset_size cells.
We pass a load_adata callback that keeps only what we need: the raw counts and the cell_type label.
This avoids loading unused obs columns and keeps every stored matrix to a single dtype.
Uniform dtypes matter for performance; see Preshuffling Performance Considerations.
from annbatch import DatasetCollection
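
def _to_raw_counts(path: str) -> ad.AnnData:
    a = ad.experimental.read_lazy(path)
    # Prefer the raw counts in .raw when present; otherwise fall back to X.
    x, var = (a.raw.X, a.raw.var) if a.raw is not None else (a.X, a.var)
    # Keep only the cell_type column of obs to avoid loading unused annotations.
    return ad.AnnData(X=x, obs=a.obs[["cell_type"]].to_memory(), var=var.to_memory())
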
collection = DatasetCollection(zarr.open("scrnaseq_collection.zarr", mode="w"))
collection.add_adatas(
    adata_paths=adata_paths,
    load_adata=_to_raw_counts,
    shuffle=True,
    dataset_size=100_000,  # cells per on-disk dataset; pick what fits comfortably in RAM
)
print("datasets in collection:", len(list(collection)))
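The three steps attributed to add_adatas() above (outer join of the gene spaces, a shuffle across all files, and a split into datasets of roughly dataset_size cells) can be sketched on toy data. This is a pure-Python illustration under stated assumptions, not annbatch's implementation: the helpers `outer_join_genes` and `shuffle_and_shard` are hypothetical, genes missing from a file are assumed to be zero-filled, and everything is held in memory.

```python
import random


def outer_join_genes(datasets):
    """Align datasets on the union of their genes; absent genes become 0.

    Each dataset is a (gene_names, rows) pair with one row of counts per cell.
    """
    all_genes = []
    for genes, _ in datasets:
        for g in genes:
            if g not in all_genes:
                all_genes.append(g)  # keep first-seen order
    aligned = []
    for genes, rows in datasets:
        col = {g: j for j, g in enumerate(genes)}
        for row in rows:
            aligned.append([row[col[g]] if g in col else 0 for g in all_genes])
    return all_genes, aligned


def shuffle_and_shard(cells, dataset_size, seed=0):
    """Shuffle cells across all inputs, then cut them into shards."""
    rng = random.Random(seed)
    order = list(range(len(cells)))
    rng.shuffle(order)
    shuffled = [cells[i] for i in order]
    return [shuffled[i:i + dataset_size] for i in range(0, len(shuffled), dataset_size)]


# Two toy "files" that share only GENE2.
file_a = (["GENE1", "GENE2"], [[1, 0], [0, 2], [3, 0]])
file_b = (["GENE2", "GENE3"], [[4, 0], [0, 5]])

genes, cells = outer_join_genes([file_a, file_b])
shards = shuffle_and_shard(cells, dataset_size=2)
print(genes)
print([len(s) for s in shards])
```

In this sketch the last shard simply takes whatever cells remain; how the library balances the final dataset is not shown here.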
Stream shuffled mini-batches
The Loader fetches preload_nchunks contiguous chunks of chunk_size cells, shuffles them together, and yields batches of batch_size.
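As a rule of thumb, keep preload_nchunks around 32 so that the shuffle pool of chunk_size * preload_nchunks cells stays well above batch_size; the larger the pool, the better the mixing. With the settings used here the pool holds 512 × 32 = 16,384 cells, four times the batch size of 4,096.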
Each batch is a sparse X; densify it on the GPU for speed.
from annbatch import Loader

def _load_adata(group: zarr.Group) -> ad.AnnData:
    return ad.AnnData(
        X=ad.io.sparse_dataset(group["X"]),  # on-disk sparse matrix, read on demand
        obs=ad.experimental.read_lazy(group).obs[["cell_type"]].to_memory(),
    )

loader = Loader(
    batch_size=4096,
    chunk_size=512,
    preload_nchunks=32,
    preload_to_gpu=True,
).use_collection(collection, load_adata=_load_adata)
batch = next(iter(loader))
x = batch["X"].cuda().to_dense()
print(f"batch X: {tuple(x.shape)} {x.dtype} on {x.device}")
print(f"distinct cell types in this batch: {batch['obs']['cell_type'].nunique()}")
batch X: (4096, 36406) torch.float32 on cuda:0
distinct cell types in this batch: 19
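The sampling scheme described in this section (fetch contiguous chunks, pool them, shuffle the pool, slice it into batches) can be sketched with plain integers standing in for cells. This is an illustration of the scheme as described, not the Loader's actual code: the function name `chunked_shuffle_batches` is hypothetical, and letting the last batch of each pool run short is a simplification.

```python
import random


def chunked_shuffle_batches(n_obs, chunk_size, preload_nchunks, batch_size, seed=0):
    """Yield batches of cell indices using chunk-level, then pool-level shuffling."""
    rng = random.Random(seed)
    # Contiguous chunks: each one corresponds to a single slice on disk.
    chunk_starts = list(range(0, n_obs, chunk_size))
    rng.shuffle(chunk_starts)  # visit chunks in random order
    for i in range(0, len(chunk_starts), preload_nchunks):
        pool = []
        for start in chunk_starts[i:i + preload_nchunks]:
            pool.extend(range(start, min(start + chunk_size, n_obs)))
        rng.shuffle(pool)  # mix cells from all preloaded chunks
        for j in range(0, len(pool), batch_size):
            yield pool[j:j + batch_size]


# Settings from this guide: a full pool holds 512 * 32 = 16,384 cells,
# i.e. 4 batches of 4,096, each mixing cells from 32 regions of the data.
batches = list(chunked_shuffle_batches(
    n_obs=171_792, chunk_size=512, preload_nchunks=32, batch_size=4096
))
print(len(batches), "batches;", len(batches[0]), "cells in the first")
```

The trade-off is visible in the two shuffles: reads stay contiguous because only whole chunks are fetched, while randomness within a batch is bounded by how many chunks share a pool.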
Throughput
A full pass over the collection, densifying every batch on the GPU.
import time
n_cells = loader.n_obs
start = time.perf_counter()
for batch in loader:
    _ = batch["X"].cuda().to_dense()
elapsed = time.perf_counter() - start
print(f"streamed {n_cells:,} cells in {elapsed:.1f}s → {n_cells / elapsed:,.0f} cells/s")
streamed 171,792 cells in 1.1s → 151,953 cells/s
Extend a collection with new data
When new samples arrive you do not need to re-shuffle everything.
Calling add_adatas() again shuffles the new cells into the existing datasets and subsets the new data to the collection’s gene space.
Here we re-add one file to show the call; in practice you pass genuinely new files.
collection.add_adatas(adata_paths=["human_dataset_1.h5ad"], load_adata=_to_raw_counts)
print("cells after extending:", collection.obs(columns=["cell_type"]).shape[0])
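Subsetting new data to an existing gene space amounts to reindexing its columns. The sketch below assumes that genes unknown to the collection are dropped and that collection genes absent from the new file are zero-filled; the source does not spell out either detail, and `align_to_gene_space` is a hypothetical helper, not an annbatch function.

```python
def align_to_gene_space(collection_genes, new_genes, new_rows):
    """Reorder and subset new_rows so their columns match collection_genes."""
    col = {g: j for j, g in enumerate(new_genes)}
    return [
        # Assumption: a gene the new file lacks contributes a zero count.
        [row[col[g]] if g in col else 0 for g in collection_genes]
        for row in new_rows
    ]


collection_genes = ["GENE1", "GENE2", "GENE3"]
new_genes = ["GENE3", "GENE9", "GENE1"]  # GENE9 is unknown to the collection
new_rows = [[7, 1, 0], [0, 2, 4]]

aligned = align_to_gene_space(collection_genes, new_genes, new_rows)
print(aligned)
```

Fixing the gene space this way keeps the number of columns stable across extensions, so batches keep the same width for a model trained on the collection.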
Scaling up
To scale up, pass more files to add_adatas() and raise dataset_size to the largest number of cells that fits comfortably in RAM.
For class-balanced batches (e.g. to oversample rare cell types), write the collection with groupby="cell_type" and use a class-aware sampler; see Advanced: Implementing a Custom Sampler.
For perturbation data, ClassSampler lets you sample equally across conditions.
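The idea behind class-balanced sampling can be sketched without any annbatch machinery: draw a class uniformly, then draw a cell from that class. The function `balanced_batch` below is a hypothetical illustration on toy labels; it is not the ClassSampler API and does not show how a custom sampler plugs into the Loader.

```python
import random
from collections import Counter, defaultdict


def balanced_batch(labels, batch_size, seed=0):
    """Sample cell indices so every class is equally likely (with replacement)."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    classes = sorted(by_class)
    batch = []
    for _ in range(batch_size):
        label = rng.choice(classes)                # uniform over classes
        batch.append(rng.choice(by_class[label]))  # uniform within the class
    return batch


# 990 common cells and 10 rare ones: uniform sampling over cells would
# pick a rare cell only about 1% of the time.
labels = ["T cell"] * 990 + ["rare cell"] * 10
batch = balanced_batch(labels, batch_size=1000)
print(Counter(labels[i] for i in batch))
```

Because sampling is with replacement, the ten rare cells are repeated many times per batch; that repetition is the cost of oversampling.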
To shard a collection across multiple GPUs, see Distributed / multi-GPU training.