annbatch.samplers.ClassSampler
- class annbatch.samplers.ClassSampler(chunk_size, preload_nchunks, batch_size, *, classes, num_samples, class_weights=None, mask=None, drop_last=False, rng=None)
Sample class-coherent batches with replacement.
Every batch the Loader yields is drawn from a single class. A class is drawn c ~ Categorical(p), with p proportional to class_weights (uniform by default), and then the batch's observations are drawn from c. A load request may span several classes, but no batch mixes them, which makes over- or under-sampling specific populations straightforward.

Sampling is with replacement: each pass draws num_samples observations rather than partitioning a fixed epoch, so there is no notion of an epoch and the number of iterations is fixed in advance by num_samples. The only size requirement is that chunk_size * preload_nchunks is divisible by batch_size (already enforced by the loader).

Class selection. A class with a non-positive weight is excluded: it is never sampled, and its runs are exempt from the run-length rule below. Set a weight to 0 to drop a class; there is no separate exclusion argument.

Run-length rule. Every contiguous run of a non-excluded class must span at least chunk_size observations. Otherwise no chunk-sized slice fits inside it, and the sampler raises at construction, naming the offending classes by their label.

Mask. Assigning mask restricts sampling to a contiguous observation range [start, stop). The run-length encoding (RLE) of the class codes is rebuilt over that window (slice starts stay in global coordinates) and cached on the resolved (start, stop) pair, so reassigning the same mask is free. Class weights are renormalized from their original values over only the classes present in the new range; if no class with a positive weight remains, the assignment raises.

Multiple workers are not supported with this sampler.
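The class-selection and mask rules can be read as one function that turns class_weights into the distribution p. Below is a minimal NumPy/pandas sketch, not annbatch's code: the helper class_probabilities and its start/stop arguments (standing in for a resolved mask) are hypothetical. It assumes, as described above, that the original weights are renormalized over the classes that are both present in the window and positively weighted.

```python
import numpy as np
import pandas as pd


def class_probabilities(classes, class_weights=None, start=0, stop=None):
    """Hypothetical helper: probabilities p for c ~ Categorical(p) over [start, stop).

    Assumes classes is a pandas.Categorical without NA codes.
    """
    n_cat = len(classes.categories)
    # class_weights=None means uniform weights over all categories.
    if class_weights is None:
        weights = np.ones(n_cat)
    else:
        weights = np.asarray(class_weights, dtype=float)
    if weights.shape != (n_cat,):
        raise ValueError("need exactly one weight per category")
    codes = np.asarray(classes.codes)[start:stop]
    # Only classes that occur inside the window can be drawn.
    present = np.bincount(codes, minlength=n_cat) > 0
    # Non-positive weights exclude a class; renormalize the original weights over the rest.
    w = np.where(present & (weights > 0), weights, 0.0)
    if w.sum() == 0:
        raise ValueError("no class with a positive weight remains in this range")
    return w / w.sum()


classes = pd.Categorical(["a", "a", "b", "b", "b", "c"])
counts = np.bincount(classes.codes, minlength=len(classes.categories))  # [2, 3, 1]

p_uniform = class_probabilities(classes)                     # [1/3, 1/3, 1/3]
p_proportional = class_probabilities(classes, counts)        # [2/6, 3/6, 1/6]
p_without_c = class_probabilities(classes, [1.0, 1.0, 0.0])  # [0.5, 0.5, 0.0]
p_window = class_probabilities(classes, counts, start=2, stop=6)  # "a" absent: [0.0, 0.75, 0.25]

rng = np.random.default_rng(0)
c = rng.choice(len(classes.categories), p=p_proportional)  # one class index per draw
```

Passing the per-class counts is the "pass each class's observation count" recipe for approximately class-agnostic sampling; with ClassSampler itself, that array would simply be given as class_weights.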
Implementation
A run-length encoding (RLE) of classes.codes is built over the mask range. A class boundary may only fall where a chunk edge and a batch edge coincide, which happens every lcm(chunk_size, batch_size) rows. Classes are therefore assigned per group of group_chunks = batch_size // gcd(chunk_size, batch_size) chunks (one lcm block): one class c ~ Categorical(p) is drawn independently for each group and shared across its chunks. Each chunk is then a single-class on-disk read, and each batch falls inside one group, hence within one class.

Drawing per minimal group packs as many classes into a window as coherence allows: up to preload_nchunks // group_chunks distinct classes (equivalently, preload_nchunks * chunk_size // lcm(chunk_size, batch_size)). Groups tile each window exactly because preload_nchunks is always a multiple of group_chunks: with g = gcd(chunk_size, batch_size), divisibility of chunk_size * preload_nchunks by batch_size means (chunk_size / g) * preload_nchunks is divisible by batch_size / g = group_chunks, and since chunk_size / g and group_chunks are coprime, group_chunks divides preload_nchunks.

A uniform chunk start within class c is drawn for each chunk (a prefix-sum lookup maps it to the absolute slice in O(log n_runs)), and rows within each batch are shuffled, so batches are class-coherent but not ordered. Memory scales with the number of runs (at most n_obs // chunk_size).

Examples
>>> from annbatch import Loader
>>> from annbatch.samplers import ClassSampler
>>> # Get categorical column from collection
>>> classes = collection.obs(columns=["categories"])["categories"].values
>>> sampler = ClassSampler(
...     chunk_size=10,
...     preload_nchunks=4,
...     batch_size=10,
...     classes=classes,
...     num_samples=1000,
... )
>>> loader = Loader(batch_sampler=sampler).use_collection(collection)
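The Implementation notes above can be illustrated with a simplified, self-contained NumPy sketch of one preload window: run-length encode the codes, draw one class per lcm block of group_chunks chunks, and map a uniform draw over each class's valid chunk starts back to an absolute slice with a prefix-sum binary search. All function names are hypothetical. Masks, excluded classes, and drop_last are ignored, the run-length rule is assumed to hold already, and it is an assumption (not stated in the source) that a run of length L contributes L - chunk_size + 1 candidate starts; annbatch's own code may differ.

```python
import math

import numpy as np
import pandas as pd


def run_length_encode(codes):
    """Return (starts, lengths, values) of the contiguous runs in codes."""
    codes = np.asarray(codes, dtype=np.int64)
    # A run begins at index 0 and wherever the code changes.
    starts = np.concatenate(([0], np.flatnonzero(np.diff(codes)) + 1))
    lengths = np.diff(np.concatenate((starts, [len(codes)])))
    return starts, lengths, codes[starts]


def build_start_index(starts, lengths, values, cls, chunk_size):
    """Run starts of class cls plus prefix sums of their valid chunk-start counts."""
    keep = (values == cls) & (lengths >= chunk_size)
    # Assumption: a run of length L offers L - chunk_size + 1 starts keeping the chunk inside it.
    n_valid = lengths[keep] - chunk_size + 1
    return starts[keep], np.cumsum(n_valid)


def draw_chunk_start(run_starts, cum_valid, rng):
    """Draw one chunk start uniformly over all valid starts of a class."""
    k = rng.integers(cum_valid[-1])  # uniform in [0, total number of valid starts)
    run = np.searchsorted(cum_valid, k, side="right")  # binary search: O(log n_runs)
    offset = k - (cum_valid[run - 1] if run > 0 else 0)
    return int(run_starts[run] + offset)


def sample_window(classes, p, chunk_size, preload_nchunks, batch_size, rng):
    """Return (class, slice) reads for one preload window, one class per lcm block."""
    assert (chunk_size * preload_nchunks) % batch_size == 0
    group_chunks = batch_size // math.gcd(chunk_size, batch_size)
    assert preload_nchunks % group_chunks == 0  # groups tile the window exactly
    starts, lengths, values = run_length_encode(classes.codes)
    index = {}  # per-class prefix sums, built lazily
    reads = []
    for _ in range(preload_nchunks // group_chunks):
        c = int(rng.choice(len(p), p=p))  # one class per group, shared by its chunks
        if c not in index:
            index[c] = build_start_index(starts, lengths, values, c, chunk_size)
        for _ in range(group_chunks):
            s = draw_chunk_start(*index[c], rng)
            reads.append((c, slice(s, s + chunk_size)))
    return reads


rng = np.random.default_rng(0)
classes = pd.Categorical(["a"] * 40 + ["b"] * 30 + ["a"] * 20 + ["c"] * 25)
reads = sample_window(classes, np.full(3, 1 / 3), chunk_size=6,
                      preload_nchunks=8, batch_size=4, rng=rng)
# Concatenate the reads and cut them into batches; lcm(6, 4) = 12 rows per group.
rows = np.concatenate([np.arange(sl.start, sl.stop) for _, sl in reads])
batches = rows.reshape(-1, 4)
for batch in batches:
    rng.shuffle(batch)  # shuffles rows within a batch only
```

With chunk_size=6 and batch_size=4, the lcm is 12 rows, so group_chunks = 2 and a window of 8 chunks holds up to 4 distinct classes; every 4-row batch falls inside one 12-row group and is therefore single-class.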
- param chunk_size:
Number of observations in each slice yielded. Also the minimum run length required of every non-excluded class (see the run-length rule).
- param preload_nchunks:
Number of chunks to load per iteration.
- param batch_size:
Number of observations per batch.
chunk_size * preload_nchunks must be divisible by it; it need not divide, or be a multiple of, chunk_size.
- param classes:
A pandas.Categorical with one entry per observation, e.g. df["cell_type"].values when the column already has a categorical dtype. If loading categories from a DatasetCollection, they can be retrieved via collection.obs(columns=["cell_type"])["cell_type"].values (if the column was stored with categorical dtype) or converted using pd.Categorical(collection.obs(columns=["label"])["label"]) (if stored as integers or strings). Length must equal the loader's n_obs. The obs axis need not be contiguous per class, but every run of a non-excluded class must be at least chunk_size long (see the run-length rule above). NA values (codes == -1) are not allowed.
- param num_samples:
Total number of observations to draw.
- param class_weights:
Optional weights, one per class in classes.categories (so len(class_weights) == len(classes.categories)), controlling how often each class is drawn. A non-positive weight excludes that class entirely. When None (the default), every class is drawn uniformly. To approximate plain, class-agnostic random sampling, pass each class's observation count. The weights are stored, and whenever a mask narrows the range, the weights of the classes still present are renormalized.
- param mask:
Optional contiguous observation range to restrict sampling to. Defaults to the whole dataset.
- type drop_last:
bool (default: False)
- param drop_last:
Whether to drop the last incomplete batch.
- param rng:
Random number generator. Note that torch.manual_seed() has no effect here; pass a seeded numpy.random.Generator to control randomness.
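Taken together, the constraints on chunk_size, preload_nchunks, batch_size, classes, and class_weights can be checked before building a sampler. The following is a minimal sketch, assuming classes is a pandas.Categorical; check_class_sampler_args is a hypothetical helper (ClassSampler performs its own validation, with its own messages). It also reports group_chunks and the maximum number of distinct classes per preload window from the Implementation notes.

```python
import math

import numpy as np
import pandas as pd


def check_class_sampler_args(classes, chunk_size, preload_nchunks, batch_size, class_weights=None):
    """Hypothetical pre-flight check mirroring the documented constraints.

    Returns (group_chunks, max distinct classes per preload window).
    """
    if (chunk_size * preload_nchunks) % batch_size != 0:
        raise ValueError("chunk_size * preload_nchunks must be divisible by batch_size")
    codes = np.asarray(classes.codes, dtype=np.int64)
    if (codes == -1).any():
        raise ValueError("classes must not contain NA values")
    n_cat = len(classes.categories)
    weights = np.ones(n_cat) if class_weights is None else np.asarray(class_weights, dtype=float)
    if weights.shape != (n_cat,):
        raise ValueError("class_weights needs one entry per category")
    # Run-length rule: runs of positively weighted classes must span >= chunk_size rows.
    starts = np.concatenate(([0], np.flatnonzero(np.diff(codes)) + 1))
    lengths = np.diff(np.concatenate((starts, [len(codes)])))
    values = codes[starts]
    short = (lengths < chunk_size) & (weights[values] > 0)
    if short.any():
        labels = sorted({str(classes.categories[v]) for v in values[short]})
        raise ValueError(f"runs shorter than chunk_size for classes: {labels}")
    group_chunks = batch_size // math.gcd(chunk_size, batch_size)
    return group_chunks, preload_nchunks // group_chunks


# Categories are inferred in sorted order: ["b", "t"].
classes = pd.Categorical(["t"] * 12 + ["b"] * 3 + ["t"] * 9)
print(check_class_sampler_args(classes, chunk_size=4, preload_nchunks=6, batch_size=6,
                               class_weights=[0.0, 1.0]))  # (3, 2): "b" is excluded
try:
    check_class_sampler_args(classes, chunk_size=4, preload_nchunks=6, batch_size=6)
except ValueError as err:
    print(err)  # runs shorter than chunk_size for classes: ['b']
```

Giving the short "b" run a weight of 0 is exactly the documented way to exempt it from the run-length rule.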
Attributes table

| Attribute | Description |
| --- | --- |
| batch_size | The batch size for data loading. |
| mask | The observation range this sampler operates on. |
| rng | The random number generator used by this sampler. |
| shuffle | Whether data is shuffled. |
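Methods table

| Method | Description |
| --- | --- |
| n_batches(n_obs) | Return the number of batches. |
| n_iters(n_obs) | Return the number of batches (deprecated; use n_batches()). |
| sample(n_obs) | Sample load requests given the total number of observations. |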
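Attributes
- ClassSampler.batch_size
The batch size for data loading.
- ClassSampler.mask
The observation range this sampler operates on. Assigning it restricts sampling to that range (see Mask above).
- ClassSampler.rng
The random number generator used by this sampler.
- ClassSampler.shuffle
Whether data is shuffled.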
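Since the sampler draws from a numpy.random.Generator rather than torch's global RNG, reproducibility comes from seeding that generator. A small NumPy-only illustration of the principle (not annbatch code; draw_classes and the probabilities p are hypothetical): a freshly seeded generator reproduces the same class sequence, while reusing one generator object continues its stream.

```python
import numpy as np

# Class probabilities, e.g. normalized class_weights (illustrative values).
p = np.array([0.5, 0.3, 0.2])


def draw_classes(rng, n):
    """Draw n class indices c ~ Categorical(p) from the given Generator."""
    return rng.choice(len(p), size=n, p=p)


first = draw_classes(np.random.default_rng(42), 20)
second = draw_classes(np.random.default_rng(42), 20)  # same seed, fresh generator

shared = np.random.default_rng(42)
pass_1 = draw_classes(shared, 20)  # identical to `first`
pass_2 = draw_classes(shared, 20)  # continues from the advanced state
```

For the sampler, the corresponding step is passing a seeded generator, e.g. rng=np.random.default_rng(42), when constructing ClassSampler.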
Methods
- ClassSampler.n_batches(n_obs)
Return the number of batches.
- ClassSampler.n_iters(n_obs)
Return the number of batches.
Deprecated since version 0.2.0: Use n_batches() instead.
- ClassSampler.sample(n_obs)
Sample load requests given the total number of observations.
The base implementation calls validate() and then yields the requests produced by _sample().
- Parameters:
- n_obs (int) – The total number of observations available.
- Yields:
LoadRequest – Load requests for batching data.
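The sample() description (validate first, then delegate to _sample()) follows the template-method pattern. A minimal, hypothetical sketch in plain Python: SamplerSketch and ContiguousSketch are illustrative stand-ins, the signature of validate() is assumed, and (start, stop) tuples stand in for LoadRequest objects.

```python
from abc import ABC, abstractmethod
from collections.abc import Iterator


class SamplerSketch(ABC):
    """Hypothetical base class illustrating 'validate, then yield from _sample()'."""

    def validate(self, n_obs: int) -> None:
        # Assumed check for the sketch; real samplers validate their own settings.
        if n_obs <= 0:
            raise ValueError("n_obs must be positive")

    def sample(self, n_obs: int) -> Iterator[tuple[int, int]]:
        # sample() is a generator function, so validate() runs on the first next() call.
        self.validate(n_obs)
        yield from self._sample(n_obs)

    @abstractmethod
    def _sample(self, n_obs: int) -> Iterator[tuple[int, int]]:
        """Subclasses produce the actual requests."""


class ContiguousSketch(SamplerSketch):
    """Toy subclass: yields consecutive (start, stop) ranges of chunk_size rows."""

    def __init__(self, chunk_size: int) -> None:
        self.chunk_size = chunk_size

    def _sample(self, n_obs: int) -> Iterator[tuple[int, int]]:
        for start in range(0, n_obs, self.chunk_size):
            yield (start, min(start + self.chunk_size, n_obs))


print(list(ContiguousSketch(4).sample(10)))  # [(0, 4), (4, 8), (8, 10)]
```

Keeping the shared checks in the base-class sample() lets each subclass implement only _sample(); in this sketch, validation is deferred until iteration starts because sample() is a generator.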