annbatch.samplers.DistributedSampler
- class annbatch.samplers.DistributedSampler(sampler, *, dist_info, enforce_equal_batches=True)
Distributed chunk-based sampler that shards data across distributed processes.
Partitions the full observation range into world_size contiguous shards using the mask mechanism of Sampler. Each rank receives a non-overlapping slice of the data. The shard boundaries are computed lazily, once n_obs becomes known.
When enforce_equal_batches is True (the default), the per-rank observation count is rounded down to the nearest multiple of batch_size, guaranteeing that every rank yields exactly the same number of complete batches.
Rank and world size are obtained from dist_info at construction time. The corresponding distributed framework must already be initialized.
Example
>>> from annbatch.samplers import DistributedSampler, RandomSampler
>>> sampler = RandomSampler(
...     chunk_size=256,
...     preload_nchunks=4,
...     batch_size=32,
... )
Using PyTorch distributed
>>> dist_sampler = DistributedSampler(sampler, dist_info="torch")
Using JAX
>>> dist_sampler = DistributedSampler(sampler, dist_info="jax")
Using a custom callable
>>> dist_sampler = DistributedSampler(
...     sampler,
...     dist_info=lambda: (rank, world_size),
... )
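A custom dist_info callable only needs to return (rank, world_size). As a minimal sketch, assuming the job was started by a launcher such as torchrun that exports the RANK and WORLD_SIZE environment variables, the hypothetical env_dist_info below (not part of annbatch) reads them and falls back to a single process when they are absent.

```python
import os


def env_dist_info() -> tuple[int, int]:
    """Hypothetical dist_info callable returning (rank, world_size).

    Assumes the launcher (for example torchrun) exported RANK and WORLD_SIZE;
    falls back to a single process, (0, 1), when they are unset.
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    # Reject inconsistent launcher settings before they reach the sampler.
    if world_size < 1 or not 0 <= rank < world_size:
        raise ValueError(f"invalid rank={rank} for world_size={world_size}")
    return rank, world_size
```

It would then be passed like the lambda above, e.g. >>> dist_sampler = DistributedSampler(sampler, dist_info=env_dist_info). Because dist_info is consulted at construction time, the environment variables must be set before the sampler is created.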
- Parameters:
  - sampler (Sampler) – The Sampler to distribute.
  - dist_info (Literal['torch', 'jax'] | Callable[[], tuple[int, int]]) – How to obtain rank and world size: either a string naming a distributed backend ("torch" or "jax"), or a callable that returns (rank, world_size).
  - enforce_equal_batches (bool, default: True) – If True, round each rank's observation count down to a multiple of batch_size so that all workers (ranks) yield the same number of batches. Set to False to use the raw n_obs // world_size split, which may result in an uneven number of batches per worker.
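The shard arithmetic described above can be sketched in plain Python. shard_bounds is a hypothetical helper, not part of annbatch; it assumes shards are laid out back to back starting at observation 0 and that each rank's offset is computed from the (possibly rounded) per-rank count, a detail the description above does not specify.

```python
def shard_bounds(
    n_obs: int,
    rank: int,
    world_size: int,
    batch_size: int,
    enforce_equal_batches: bool = True,
) -> tuple[int, int]:
    """Hypothetical helper: the [start, stop) observation range owned by `rank`."""
    per_rank = n_obs // world_size  # raw split; any remainder is left unused
    if enforce_equal_batches:
        # Round down so every rank holds the same whole number of batches.
        per_rank = (per_rank // batch_size) * batch_size
    start = rank * per_rank  # assumes shards are laid out back to back from 0
    return start, start + per_rank


for rank in range(3):
    start, stop = shard_bounds(1000, rank, world_size=3, batch_size=32)
    print(rank, (start, stop), (stop - start) // 32, "batches")
```

In this sketch, 1000 observations over 3 ranks with batch_size=32 give each rank 320 observations (10 batches) and leave the last 40 unused; with enforce_equal_batches=False each rank instead gets the raw 333.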
Attributes table
| Attribute | Description |
|---|---|
| batch_size | The batch size for data loading. |
| mask | The observation range this sampler operates on. |
| rng | The random number generator used by this sampler. |
| shuffle | Whether data is shuffled. |
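Methods table
| Method | Description |
|---|---|
| n_iters(n_obs) | Return the number of batches. |
| sample(n_obs) | Sample load requests given the total number of observations. |
| validate(n_obs) | Validate the sampler configuration against the given n_obs. |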
Attributes
- DistributedSampler.batch_size
The batch size for data loading.
- DistributedSampler.mask
The observation range this sampler operates on.
- DistributedSampler.rng
The random number generator used by this sampler.
- DistributedSampler.shuffle
Whether data is shuffled.
Methods
- DistributedSampler.n_iters(n_obs)
Return the number of batches.
- DistributedSampler.sample(n_obs)
Sample load requests given the total number of observations.
The base implementation simply calls validate() and then yields via _sample().
- Parameters:
  - n_obs (int) – The total number of observations available.
- Yields:
LoadRequest – Load requests for batching data.
- Return type:
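The validate-then-_sample structure of sample() is a template method: subclasses customize _sample() (and optionally validate()) while sample() fixes the order in which they run. ToySampler below is a hypothetical illustration of that pattern, not annbatch's Sampler; it yields plain (start, stop) tuples instead of LoadRequest objects, and its drop-the-partial-batch rule is an assumption made for the example.

```python
from typing import Iterator


class ToySampler:
    """Hypothetical stand-in illustrating the validate-then-_sample pattern."""

    def __init__(self, batch_size: int) -> None:
        self.batch_size = batch_size

    def validate(self, n_obs: int) -> None:
        # Reject configurations that cannot produce a single batch.
        if n_obs < self.batch_size:
            raise ValueError(
                f"n_obs={n_obs} is smaller than batch_size={self.batch_size}"
            )

    def _sample(self, n_obs: int) -> Iterator[tuple[int, int]]:
        # Emit contiguous, complete batches; a trailing partial batch is dropped.
        for start in range(0, n_obs - self.batch_size + 1, self.batch_size):
            yield start, start + self.batch_size

    def sample(self, n_obs: int) -> Iterator[tuple[int, int]]:
        # Validation runs once per sample() call, before anything is yielded.
        self.validate(n_obs)
        yield from self._sample(n_obs)


print(list(ToySampler(batch_size=4).sample(10)))
```

Because sample() in this sketch is a generator, a ValueError from validate() surfaces on the first next() call rather than at the moment sample() is invoked.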
- DistributedSampler.validate(n_obs)
Validate the sampler configuration against the given n_obs.
This method is called at the start of each sample() call. Override it to add custom validation for sampler parameters.
- Parameters:
  - n_obs (int) – The total number of observations in the loader.
- Raises:
ValueError – If the sampler configuration is invalid for the given n_obs.
- Return type:
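Overriding validate() usually means extending the parent's checks rather than replacing them. The sketch below uses hypothetical BaseSampler and MinBatchPerRankSampler classes (neither is part of annbatch) to add one rule relevant to distributed use: when enforce_equal_batches rounds down to a multiple of batch_size, a rank with fewer than batch_size observations would end up with none, so this subclass rejects that case up front.

```python
class BaseSampler:
    """Hypothetical stand-in for a sampler base class; only validate() matters here."""

    def __init__(self, batch_size: int) -> None:
        self.batch_size = batch_size

    def validate(self, n_obs: int) -> None:
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {self.batch_size}")


class MinBatchPerRankSampler(BaseSampler):
    """Hypothetical subclass requiring at least one full batch per rank."""

    def __init__(self, batch_size: int, world_size: int) -> None:
        super().__init__(batch_size)
        self.world_size = world_size

    def validate(self, n_obs: int) -> None:
        super().validate(n_obs)  # keep the parent's checks
        per_rank = n_obs // self.world_size
        if per_rank < self.batch_size:
            raise ValueError(
                f"n_obs={n_obs} gives {per_rank} observations per rank, "
                f"fewer than batch_size={self.batch_size}"
            )
```

Calling super().validate() first keeps the base class's rules in force, so the subclass only has to express what is new.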