annbatch.samplers.DistributedSampler

class annbatch.samplers.DistributedSampler(sampler, *, dist_info, enforce_equal_batches=True)

Distributed chunk-based sampler that shards data across distributed processes.

Partitions the full observation range into world_size contiguous shards using the mask mechanism of Sampler. Each rank receives a non-overlapping slice of the data. The shard boundaries are computed lazily when n_obs becomes known.
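A minimal sketch of contiguous, non-overlapping sharding follows. It assumes each rank receives `n_obs // world_size` observations laid out back to back from index 0 (the raw split named under `enforce_equal_batches`). `shard_range` is a hypothetical helper, not part of annbatch, and it returns a plain `range` rather than the sampler's actual mask object.

```python
def shard_range(n_obs: int, rank: int, world_size: int) -> range:
    """Hypothetical helper: the contiguous block of observation indices owned by `rank`."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is out of range for world_size {world_size}")
    per_rank = n_obs // world_size  # raw split; the n_obs % world_size tail is unassigned
    start = rank * per_rank
    return range(start, start + per_rank)


# 10 observations over 4 ranks: each rank gets 2, observations 8 and 9 are unused.
print([shard_range(10, r, 4) for r in range(4)])
# [range(0, 2), range(2, 4), range(4, 6), range(6, 8)]
```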

When enforce_equal_batches is True (the default), the per-rank observation count is rounded down to the nearest multiple of batch_size, guaranteeing that every rank yields exactly the same number of complete batches.
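The rounding rule reduces to plain integer arithmetic. `per_rank_count` is a hypothetical helper, and the sketch assumes the rounding is applied to the raw `n_obs // world_size` share, as the `enforce_equal_batches` parameter description states.

```python
def per_rank_count(n_obs: int, world_size: int, batch_size: int,
                   enforce_equal_batches: bool = True) -> int:
    """Hypothetical helper: number of observations assigned to each rank."""
    per_rank = n_obs // world_size
    if enforce_equal_batches:
        # Round down to the nearest multiple of batch_size.
        per_rank = (per_rank // batch_size) * batch_size
    return per_rank


# 1000 observations on 3 ranks with batch_size=32:
# raw share 1000 // 3 = 333, rounded down to 320, i.e. 10 complete batches per rank.
n = per_rank_count(1000, 3, 32)
print(n, n // 32)  # 320 10
```

Under this assumption, 3 × 320 = 960 of the 1000 observations are visited per epoch.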

Rank and world size are obtained from dist_info at construction time. The corresponding distributed framework must already be initialized.
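One way the three accepted forms of dist_info could be resolved to (rank, world_size) is sketched here. `resolve_dist_info` is a hypothetical helper, not annbatch's implementation. The "torch" branch uses real `torch.distributed` calls; mapping "jax" to `jax.process_index()` and `jax.process_count()` (one shard per JAX process) is an assumption.

```python
from typing import Callable, Literal, Union

DistInfo = Union[Literal["torch", "jax"], Callable[[], tuple[int, int]]]


def resolve_dist_info(dist_info: DistInfo) -> tuple[int, int]:
    """Hypothetical helper: return (rank, world_size) for the given dist_info."""
    if callable(dist_info):
        rank, world_size = dist_info()
    elif dist_info == "torch":
        import torch.distributed as dist  # process group must already be initialized
        if not dist.is_initialized():
            raise RuntimeError("torch.distributed is not initialized")
        rank, world_size = dist.get_rank(), dist.get_world_size()
    elif dist_info == "jax":
        import jax  # assumption: one shard per JAX process
        rank, world_size = jax.process_index(), jax.process_count()
    else:
        raise ValueError(f"unsupported dist_info: {dist_info!r}")
    if not 0 <= rank < world_size:
        raise ValueError(f"invalid rank {rank} for world_size {world_size}")
    return rank, world_size
```

Resolving once and storing the result matches the statement above that rank and world size are read at construction time.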

Example

>>> from annbatch.samplers import DistributedSampler, RandomSampler
>>> sampler = RandomSampler(
...     chunk_size=256,
...     preload_nchunks=4,
...     batch_size=32,
... )

Using PyTorch distributed

>>> dist_sampler = DistributedSampler(sampler, dist_info="torch")

Using JAX

>>> dist_sampler = DistributedSampler(sampler, dist_info="jax")

Using a custom callable

>>> rank, world_size = 0, 1  # placeholders; normally supplied by your launcher
>>> dist_sampler = DistributedSampler(
...     sampler,
...     dist_info=lambda: (rank, world_size),
... )

Parameters:
sampler Sampler

The Sampler to distribute.

dist_info Literal['torch', 'jax'] | Callable[[], tuple[int, int]]

How to obtain rank and world size. Either a string naming a distributed backend ("torch" or "jax"), or a callable that returns (rank, world_size).

enforce_equal_batches bool (default: True)

If True, round each rank’s observation count down to a multiple of batch_size so that all workers (ranks) yield the same number of batches. Set to False to use the raw n_obs // world_size split, which may result in an uneven number of batches per worker.
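For the callable form of dist_info, a common pattern reads the rank and world size from the RANK and WORLD_SIZE environment variables that launchers such as torchrun export. The function name `rank_from_env` and the (0, 1) single-process fallback are illustrative assumptions.

```python
import os


def rank_from_env() -> tuple[int, int]:
    """Return (rank, world_size) from launcher-provided environment variables.

    Falls back to a single process (rank 0 of 1) when the variables are absent;
    this fallback is an assumption, not annbatch behavior.
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank, world_size
```

It would then be passed as `DistributedSampler(sampler, dist_info=rank_from_env)`.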

Attributes table

batch_size

The batch size for data loading.

mask

The observation range this sampler operates on.

rng

The random number generator used by this sampler.

shuffle

Whether data is shuffled.

Methods table

n_iters(n_obs)

Return the number of batches.

sample(n_obs)

Sample load requests given the total number of observations.

validate(n_obs)

Validate the sampler configuration against the given n_obs.

Attributes

DistributedSampler.batch_size

The batch size for data loading.

DistributedSampler.mask

The observation range this sampler operates on.

DistributedSampler.rng

The random number generator used by this sampler.

DistributedSampler.shuffle

Whether data is shuffled.

Methods

DistributedSampler.n_iters(n_obs)

Return the number of batches.

Parameters:
n_obs int

The total number of observations available.

Return type:

int

Returns:

int – The total number of batches this sampler will produce.

DistributedSampler.sample(n_obs)

Sample load requests given the total number of observations.

The base implementation calls validate() and then yields from _sample().

Parameters:
n_obs int

The total number of observations available.

Yields:

LoadRequest – Load requests for batching data.

Return type:

Iterator[LoadRequest]
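The validate-then-delegate flow of sample() can be sketched with a small abstract base class. Everything here is hypothetical: `SketchSampler` and `SequentialSketch` are illustrative names, and plain `range` objects stand in for annbatch's LoadRequest.

```python
from abc import ABC, abstractmethod
from collections.abc import Iterator


class SketchSampler(ABC):
    """Hypothetical base class illustrating the validate-then-yield pattern."""

    def validate(self, n_obs: int) -> None:
        """Hook for subclasses; this base version only rejects negative n_obs."""
        if n_obs < 0:
            raise ValueError(f"n_obs must be non-negative, got {n_obs}")

    def sample(self, n_obs: int) -> Iterator[range]:
        self.validate(n_obs)  # runs when iteration starts
        yield from self._sample(n_obs)

    @abstractmethod
    def _sample(self, n_obs: int) -> Iterator[range]:
        """Subclasses produce the actual requests here."""


class SequentialSketch(SketchSampler):
    """Yields consecutive index ranges of at most chunk_size observations."""

    def __init__(self, chunk_size: int) -> None:
        self.chunk_size = chunk_size

    def _sample(self, n_obs: int) -> Iterator[range]:
        for start in range(0, n_obs, self.chunk_size):
            yield range(start, min(start + self.chunk_size, n_obs))


print(list(SequentialSketch(4).sample(10)))
# [range(0, 4), range(4, 8), range(8, 10)]
```

Because sample() is a generator function in this sketch, validate() runs on the first next() call rather than at the moment sample() is invoked.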

DistributedSampler.validate(n_obs)

Validate the sampler configuration against the given n_obs.

This method is called at the start of each sample() call. Override this method to add custom validation for sampler parameters.

Parameters:
n_obs int

The total number of observations in the loader.

Raises:

ValueError – If the sampler configuration is invalid for the given n_obs.

Return type:

None
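As an example of a validate() override, a sharded sampler might reject configurations in which rounding each rank's share down to a multiple of batch_size leaves nothing to yield. `ShardCheck` and this particular rule are hypothetical illustrations, not checks annbatch is known to perform; the sketch assumes the `n_obs // world_size` per-rank split.

```python
class ShardCheck:
    """Hypothetical validate() override for a sharded sampler."""

    def __init__(self, batch_size: int, world_size: int,
                 enforce_equal_batches: bool = True) -> None:
        self.batch_size = batch_size
        self.world_size = world_size
        self.enforce_equal_batches = enforce_equal_batches

    def validate(self, n_obs: int) -> None:
        per_rank = n_obs // self.world_size
        if self.enforce_equal_batches and per_rank < self.batch_size:
            # Rounding down to a multiple of batch_size would give every rank 0 observations.
            raise ValueError(
                f"n_obs={n_obs} gives {per_rank} observations per rank, fewer than "
                f"batch_size={self.batch_size}; no rank would yield a batch"
            )
```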