DistributedEvaluationSamplerWrapper
class torch_brain.samplers.DistributedEvaluationSamplerWrapper(sampler, num_replicas=None, rank=None)
Bases: torch.utils.data.sampler.Sampler
Wraps any sampler for use in distributed evaluation without dropping samples.
Unlike the standard distributed samplers from PyTorch and PyTorch Lightning, which ensure equal steps per rank by potentially dropping samples, this wrapper preserves all samples by interleaving them across ranks. This guarantees that evaluation metrics are computed over the complete dataset.
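A minimal sketch of the interleaving idea, assuming a simple stride-based split in which rank `r` takes positions `r, r + num_replicas, r + 2 * num_replicas, ...` of the base sampler's output. The helper `interleave_shard` is hypothetical; the exact assignment order used by torch_brain is not stated here and may differ.

```python
from typing import List, Sequence


def interleave_shard(indices: Sequence[int], num_replicas: int, rank: int) -> List[int]:
    """Return the indices assigned to `rank` under stride-based interleaving.

    Hypothetical helper: no padding and no dropping, so shard lengths
    may differ by at most one between ranks.
    """
    if num_replicas < 1:
        raise ValueError("num_replicas must be >= 1")
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must satisfy 0 <= rank < num_replicas")
    return list(indices[rank::num_replicas])


# 10 items from a base sampler, split across 4 ranks.
base = list(range(10))
shards = [interleave_shard(base, num_replicas=4, rank=r) for r in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]

# Every index lands on exactly one rank; nothing is dropped or duplicated.
assert sorted(i for shard in shards for i in shard) == base
```

For comparison, PyTorch's `DistributedSampler` over the same 10 items and 4 replicas gives each rank 2 indices with `drop_last=True` (2 items are never evaluated) and 3 indices with `drop_last=False` (2 items are repeated to pad the total to 12), so in neither case is each sample counted exactly once.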
Warning
Because this wrapper does not pad the per-rank shards to equal length, some ranks will perform more steps than others. There must therefore be no inter-rank communication (e.g. allreduce) inside the evaluation loop; only barrier-style synchronization at the start and end is safe.
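To see why a per-step collective would hang, take 10 samples over 4 ranks with an interleaved split (an assumption about the exact assignment): ranks 0 and 1 run 3 steps while ranks 2 and 3 run only 2, so a third allreduce on ranks 0 and 1 would wait forever. The single-process simulation below, with hypothetical names `LocalMetric`, `evaluate_rank`, and `reduce_after_loop`, sketches the safe pattern: accumulate sums locally during the loop and combine them once afterwards. In a real job that final step would be a single collective such as `torch.distributed.all_reduce` on a tensor of the totals.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class LocalMetric:
    """Per-rank running totals; nothing is communicated while these update."""
    correct: int = 0
    total: int = 0

    def update(self, prediction: int, target: int) -> None:
        self.correct += int(prediction == target)
        self.total += 1


def evaluate_rank(shard: List[int], predictions: Dict[int, int], targets: Dict[int, int]) -> LocalMetric:
    # The number of loop iterations equals len(shard), which differs between ranks.
    metric = LocalMetric()
    for idx in shard:
        metric.update(predictions[idx], targets[idx])
    return metric


def reduce_after_loop(per_rank: List[LocalMetric]) -> float:
    # Stand-in for one post-loop collective (e.g. all_reduce of [correct, total]).
    # Summing raw counts, not per-rank averages, weights every sample equally.
    correct = sum(m.correct for m in per_rank)
    total = sum(m.total for m in per_rank)
    return correct / total


num_replicas = 4
samples = list(range(10))
targets = {i: i % 2 for i in samples}
predictions = {i: 0 for i in samples}  # a trivial model that always predicts 0
shards = [samples[r::num_replicas] for r in range(num_replicas)]
print([len(s) for s in shards])  # [3, 3, 2, 2] steps per rank

per_rank = [evaluate_rank(s, predictions, targets) for s in shards]
print(reduce_after_loop(per_rank))  # 0.5, with all 10 samples counted once
```

Reducing counts rather than per-rank accuracies matters here: averaging the four per-rank accuracies would give the 2-sample shards the same weight as the 3-sample shards.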
Rank and world size are intentionally not resolved in __init__; call set_params() once the distributed environment has been initialised, before iterating.
Parameters:
- sampler (Sampler) – The base sampler whose indices will be distributed across ranks.
- num_replicas (Optional[int]) – Total number of processes participating in evaluation. If None, it must be set later via set_params().
- rank (Optional[int]) – Rank of the current process. If None, it must be set later via set_params().
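The deferred-initialisation contract can be sketched with a small stand-in class. `LazyInterleavedSampler` is hypothetical: it does not subclass `torch.utils.data.Sampler` (so it runs without PyTorch), it assumes a stride-based split, and its rank validation and "not yet configured" error are assumptions rather than documented torch_brain behaviour. In a real `torchrun` job, the values passed to `set_params()` would typically come from `torch.distributed.get_world_size()` and `torch.distributed.get_rank()` after `torch.distributed.init_process_group()`.

```python
from typing import Iterable, Iterator, Optional


class LazyInterleavedSampler:
    """Hypothetical stand-in illustrating deferred rank/world-size resolution.

    `sampler` must be re-iterable (like a torch Sampler, a list, or a range),
    because it is iterated afresh on every pass.
    """

    def __init__(self, sampler: Iterable[int], num_replicas: Optional[int] = None, rank: Optional[int] = None):
        # Nothing is read from the distributed environment at construction time.
        self.sampler = sampler
        self.num_replicas = num_replicas
        self.rank = rank

    def set_params(self, num_replicas: int, rank: int) -> None:
        # Called once the process group exists, before the first iteration.
        if not 0 <= rank < num_replicas:
            raise ValueError(f"rank {rank} out of range for {num_replicas} replicas")
        self.num_replicas = num_replicas
        self.rank = rank

    def _check(self) -> None:
        if self.num_replicas is None or self.rank is None:
            raise RuntimeError("call set_params() before iterating")

    def __iter__(self) -> Iterator[int]:
        self._check()
        indices = list(self.sampler)
        return iter(indices[self.rank::self.num_replicas])

    def __len__(self) -> int:
        self._check()
        n = len(list(self.sampler))
        # Ranks below n % num_replicas receive one extra index.
        return n // self.num_replicas + int(self.rank < n % self.num_replicas)


# Built before any process group exists, e.g. inside a data module.
lazy = LazyInterleavedSampler(range(10))
try:
    list(lazy)
except RuntimeError as err:
    print(err)  # call set_params() before iterating

lazy.set_params(num_replicas=4, rank=2)
print(list(lazy), len(lazy))  # [2, 6] 2
```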
Example:
>>> import numpy as np
>>> from temporaldata import Interval
>>> from torch_brain.samplers import SequentialFixedWindowSampler, DistributedEvaluationSamplerWrapper
>>> sampling_intervals = {
...     "session_1": Interval(0.0, 100.0),
...     "session_2": Interval(0.0, 100.0),
... }
>>> sampler = SequentialFixedWindowSampler(
...     sampling_intervals=sampling_intervals,
...     window_length=10.0,
... )
>>> dist_sampler = DistributedEvaluationSamplerWrapper(sampler)
>>> dist_sampler.set_params(num_replicas=4, rank=0)