annbatch.samplers.ClassSampler

class annbatch.samplers.ClassSampler(chunk_size, preload_nchunks, batch_size, *, classes, num_samples, class_weights=None, mask=None, drop_last=False, rng=None)

Sample class-coherent batches with replacement.

Every batch the Loader yields is drawn from a single class: first a class c ~ Categorical(p) is drawn (p proportional to class_weights, uniform by default), then the batch’s observations are drawn from c. A load request may span several classes, but no batch mixes them, which makes over- or under-sampling specific populations straightforward.
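The two-step draw can be pictured with a small NumPy sketch. This is a conceptual illustration only, not the sampler's implementation: it ignores chunks and on-disk layout, and the function name draw_class_coherent_batch and the toy data are hypothetical.

```python
import numpy as np

def draw_class_coherent_batch(codes, p, batch_size, rng):
    """Draw one batch whose rows all share a single class (conceptual sketch)."""
    c = rng.choice(len(p), p=p)            # step 1: c ~ Categorical(p)
    members = np.flatnonzero(codes == c)   # observations that belong to class c
    rows = rng.choice(members, size=batch_size, replace=True)  # step 2: with replacement
    return c, rows

rng = np.random.default_rng(0)
codes = np.repeat([0, 1, 2], [50, 30, 20])  # toy class codes, one per observation
p = np.array([0.2, 0.2, 0.6])               # over-sample the rarest class
c, rows = draw_class_coherent_batch(codes, p, batch_size=8, rng=rng)
```

Because the class is chosen before the rows, the share of batches per class follows p rather than the class sizes in the data.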

Sampling is with replacement: each pass draws num_samples observations rather than partitioning a fixed epoch, so there is no notion of an epoch and the number of iterations is fixed by num_samples. The only size requirement is that chunk_size * preload_nchunks be divisible by batch_size (already enforced by the loader).
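The divisibility requirement is simple arithmetic. The helper below is a hypothetical stand-alone check, not part of annbatch, that shows which size combinations pass and how many batches one load then contains.

```python
def batches_per_load(chunk_size, preload_nchunks, batch_size):
    """Return how many full batches one load holds, or raise if the sizes do not fit."""
    loaded = chunk_size * preload_nchunks  # observations fetched per load
    if loaded % batch_size != 0:
        raise ValueError(
            f"chunk_size * preload_nchunks = {loaded} is not divisible by "
            f"batch_size = {batch_size}"
        )
    return loaded // batch_size

# batch_size need not divide or be a multiple of chunk_size:
# 3 chunks of 4 rows hold exactly 2 batches of 6 rows.
n_small = batches_per_load(chunk_size=4, preload_nchunks=3, batch_size=6)
n_doc = batches_per_load(chunk_size=10, preload_nchunks=4, batch_size=10)
```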

Class selection. A class with a non-positive weight is excluded: it is never sampled and its runs are exempt from the run-length rule below. Set a weight to 0 to drop a class; there is no separate exclusion argument.
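A minimal sketch of how weights could translate into draw probabilities under this rule, assuming plain normalization over the positive entries. The helper class_probabilities is hypothetical and only mirrors the behaviour described above.

```python
import numpy as np

def class_probabilities(class_weights):
    """Zero out non-positive weights and normalize the rest to sum to 1."""
    w = np.asarray(class_weights, dtype=float)
    w = np.where(w > 0, w, 0.0)   # a non-positive weight excludes the class
    total = w.sum()
    if total == 0:
        raise ValueError("no class has a positive weight")
    return w / total

# The second and fourth classes are excluded and are never drawn.
p = class_probabilities([2.0, 0.0, 1.0, -1.0])
```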

Run-length rule. Every contiguous run of a non-excluded class must span at least chunk_size observations; otherwise no aligned slice fits inside it and the sampler raises at construction, naming the offending classes by their label.
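To check a dataset against this rule before constructing the sampler, a run-length pass over the class codes is enough. The sketch below is an independent illustration with hypothetical helper names; it reports integer codes, whereas the sampler names classes by label.

```python
import numpy as np

def run_length_encode(codes):
    """Return (starts, lengths, values) for the contiguous runs in codes."""
    codes = np.asarray(codes)
    if codes.size == 0:
        empty = np.array([], dtype=int)
        return empty, empty, codes
    change = np.flatnonzero(codes[1:] != codes[:-1]) + 1  # first index of each new run
    starts = np.concatenate(([0], change))
    lengths = np.diff(np.concatenate((starts, [codes.size])))
    return starts, lengths, codes[starts]

def classes_with_short_runs(codes, chunk_size, excluded=()):
    """Codes of non-excluded classes that have a run shorter than chunk_size."""
    _, lengths, values = run_length_encode(codes)
    short = values[lengths < chunk_size]
    return sorted(set(short.tolist()) - set(excluded))

# Class 1 has a run of only 3 observations, so it violates the rule for chunk_size=10.
codes = np.array([0] * 10 + [1] * 3 + [0] * 12 + [2] * 10)
offending = classes_with_short_runs(codes, chunk_size=10)
after_exclusion = classes_with_short_runs(codes, chunk_size=10, excluded={1})
```

Excluding the offending class (weight 0) is one way to satisfy the rule without reordering the data.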

Mask. Assigning mask restricts sampling to a contiguous observation range [start, stop). The run-length encoding (RLE) of the class codes is rebuilt over that window (slice starts stay in global coordinates) and cached on the resolved (start, stop) pair, so reassigning the same mask is free. Class weights are renormalized from the original values over only the classes present in the new range; if no class with a positive weight remains, the assignment raises.
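The renormalization step can be sketched as follows. This is an assumption-laden illustration, not the library's code: the function renormalize_for_mask is hypothetical, and it resolves the slice with the built-in slice.indices.

```python
import numpy as np

def renormalize_for_mask(codes, class_weights, mask, n_obs):
    """Resolve mask to (start, stop) and renormalize the original weights
    over the classes that occur inside that window."""
    start, stop, step = mask.indices(n_obs)
    if step != 1:
        raise ValueError("mask must be a contiguous range")
    w = np.asarray(class_weights, dtype=float)
    present = np.zeros(w.size, dtype=bool)
    present[np.unique(codes[start:stop])] = True   # classes seen in the window
    kept = np.where(present & (w > 0), w, 0.0)     # start again from the original weights
    if kept.sum() == 0:
        raise ValueError("no class with a positive weight remains in the mask")
    return (start, stop), kept / kept.sum()

codes = np.repeat([0, 1, 2], [20, 20, 20])
weights = [1.0, 3.0, 4.0]
window, p = renormalize_for_mask(codes, weights, slice(0, 40), n_obs=60)
full_window, p_full = renormalize_for_mask(codes, weights, slice(None), n_obs=60)
```

Class 2 lies outside the first window, so its weight drops out and the remaining weights 1 and 3 become 0.25 and 0.75. Widening the mask again restores the original proportions because the stored weights are never overwritten.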

Multiple workers are not supported with this sampler.

Implementation

A run-length encoding (RLE) of classes.codes is built over the mask range.

A class boundary may only fall where a chunk edge and a batch edge coincide, which happens every lcm(chunk_size, batch_size) rows. Classes are therefore assigned per group of group_chunks = batch_size // gcd(chunk_size, batch_size) chunks (one lcm block): one class c ~ Categorical(p) is drawn independently for each group and shared across its chunks. Each chunk is then a single-class on-disk read and each batch falls inside one group, hence one class.

Drawing per minimal group packs as many classes into a window as coherence allows: up to preload_nchunks // group_chunks distinct classes (equivalently preload_nchunks * chunk_size // lcm(chunk_size, batch_size)). preload_nchunks is always a multiple of group_chunks because chunk_size * preload_nchunks is divisible by batch_size, so groups tile each window.

A uniform chunk-start within c is drawn per chunk (a prefix-sum lookup maps it to the absolute slice in O(log n_runs)), and rows within each batch are shuffled, so batches are class-coherent but not ordered. Memory scales with the number of runs (<= n_obs // chunk_size).
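The group arithmetic and the prefix-sum lookup can be reproduced in a few lines. This is a rough sketch, not the library's code: all function names are hypothetical, and locate_start additionally assumes that every offset leaving room for a full chunk inside a run counts as a valid start, which the text above does not specify.

```python
import math
import numpy as np

def group_layout(chunk_size, preload_nchunks, batch_size):
    """Return (group_chunks, rows per lcm block, groups per window)."""
    g = math.gcd(chunk_size, batch_size)
    group_chunks = batch_size // g
    block_rows = chunk_size * batch_size // g      # lcm(chunk_size, batch_size)
    return group_chunks, block_rows, preload_nchunks // group_chunks

def draw_chunk_classes(p, chunk_size, preload_nchunks, batch_size, rng):
    """Draw one class per group and share it across the chunks of that group."""
    group_chunks, _, n_groups = group_layout(chunk_size, preload_nchunks, batch_size)
    group_classes = rng.choice(len(p), size=n_groups, p=p)
    return np.repeat(group_classes, group_chunks)

def locate_start(draw, run_starts, valid_counts):
    """Map a flat index over all valid chunk starts of one class to an absolute row.
    ASSUMPTION: valid_counts[i] = run_length[i] - chunk_size + 1."""
    cum = np.cumsum(valid_counts)                     # prefix sums over the runs
    run = np.searchsorted(cum, draw, side="right")    # binary search, O(log n_runs)
    offset = draw - (cum[run] - valid_counts[run])    # position inside that run
    return int(run_starts[run] + offset)

layout = group_layout(chunk_size=4, preload_nchunks=6, batch_size=6)
chunk_classes = draw_chunk_classes(
    np.array([0.5, 0.5]), chunk_size=4, preload_nchunks=6, batch_size=6,
    rng=np.random.default_rng(0),
)
# Two runs of one class: rows 0..11 and 100..109, with chunk_size=10.
run_starts = np.array([0, 100])
valid_counts = np.array([12, 10]) - 10 + 1   # 3 and 1 valid starts
first_of_second_run = locate_start(3, run_starts, valid_counts)
```

With chunk_size=4 and batch_size=6, a group is 3 chunks (12 rows, two batches), so a window of 6 chunks holds at most 2 distinct classes.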

Examples

>>> from annbatch import Loader
>>> from annbatch.samplers import ClassSampler
>>> # `collection` is an existing DatasetCollection; get a categorical obs column from it
>>> classes = collection.obs(columns=["categories"])["categories"].values
>>> sampler = ClassSampler(
...     chunk_size=10,
...     preload_nchunks=4,
...     batch_size=10,
...     classes=classes,
...     num_samples=1000,
... )
>>> loader = Loader(batch_sampler=sampler).use_collection(collection)

Parameters:
chunk_size int

Number of observations in each slice yielded. Also the minimum run length required of every non-excluded class (see the run-length rule).

preload_nchunks int

Number of chunks to load per iteration.

batch_size int

Number of observations per batch. chunk_size * preload_nchunks must be divisible by it; it need not divide or be a multiple of chunk_size.

classes Categorical

A pandas.Categorical with one entry per observation, e.g. df["cell_type"].values when the column already has a categorical dtype. If loading categories from a DatasetCollection, they can be retrieved via collection.obs(columns=["cell_type"])["cell_type"].values (if the column was stored with categorical dtype) or converted using pd.Categorical(collection.obs(columns=["label"])["label"]) (if stored as integers or strings). Length must equal the loader’s n_obs. The obs axis need not be contiguous per class, but every run of a non-excluded class must be at least chunk_size long (see the run-length rule above). NA values (codes == -1) are not allowed.

num_samples int

Total number of observations to draw.

class_weights ndarray | None (default: None)

Optional weights, one per class in classes.categories (so len(class_weights) == len(classes.categories)), controlling how often each class is drawn. A non-positive weight excludes that class entirely. When None (the default) every class is drawn uniformly. For proportional (≈ plain global random) sampling pass each class’s observation count. The weights are kept and, whenever a mask narrows the range, the weights of the classes still present are renormalized.

mask slice | None (default: None)

Optional contiguous observation range to restrict sampling to. Defaults to the whole dataset.

drop_last bool (default: False)

Whether to drop the last incomplete batch.

rng Generator | None (default: None)

Random number generator. Note that torch.manual_seed() has no effect here; pass a seeded numpy.random.Generator to control randomness.
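As a sketch of how these arguments might be prepared, the snippet below builds a toy pandas.Categorical, derives proportional class_weights from the observation counts, and creates a seeded generator. The labels and sizes are made up for illustration, and the final constructor call is left commented out so the snippet runs without annbatch installed.

```python
import numpy as np
import pandas as pd

# Toy labels: three classes stored contiguously, each run >= chunk_size (10).
labels = np.repeat(["B", "T", "NK"], [40, 30, 30])
classes = pd.Categorical(labels)

# NA values show up as code -1 and are not allowed.
assert not (classes.codes == -1).any()

# Proportional sampling: one weight per entry of classes.categories,
# equal to that class's observation count.
counts = np.bincount(classes.codes, minlength=len(classes.categories))
class_weights = counts.astype(float)

# torch.manual_seed() does not affect the sampler; seed a NumPy Generator instead.
rng = np.random.default_rng(42)

sampler_kwargs = dict(
    chunk_size=10,
    preload_nchunks=4,
    batch_size=10,
    classes=classes,
    num_samples=1000,
    class_weights=class_weights,
    rng=rng,
)
# from annbatch.samplers import ClassSampler
# sampler = ClassSampler(**sampler_kwargs)
```

Note that the weights follow the order of classes.categories (sorted here as B, NK, T), not the order in which labels first appear in the data.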

Attributes table

batch_size

The batch size for data loading.

mask

The observation range this sampler operates on.

rng

The random number generator used by this sampler.

shuffle

Whether data is shuffled.

Methods table

n_batches(n_obs)

Return the number of batches.

n_iters(n_obs)

Return the number of batches.

sample(n_obs)

Sample load requests given the total number of observations.

validate(n_obs)

Validate that the codes describe exactly the loader's observations.

Attributes

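ClassSampler.batch_size

The batch size for data loading.

ClassSampler.mask

The observation range this sampler operates on.

ClassSampler.rng

The random number generator used by this sampler.

ClassSampler.shuffle

Whether data is shuffled.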

Methods

ClassSampler.n_batches(n_obs)

Return the number of batches.

Parameters:
n_obs int

The total number of observations available.

Return type:

int

Returns:

The total number of batches this sampler will produce.

ClassSampler.n_iters(n_obs)

Return the number of batches.

Deprecated since version 0.2.0: Use n_batches() instead.

Return type:

int

ClassSampler.sample(n_obs)

Sample load requests given the total number of observations.

Base implementation simply calls validate() and then yields via _sample().

Parameters:
n_obs int

The total number of observations available.

Yields:

LoadRequest – Load requests for batching data.

Return type:

Iterator[LoadRequest]

ClassSampler.validate(n_obs)

Validate that the codes describe exactly the loader’s observations.

Return type:

None