DistributedEvaluationSamplerWrapper
class torch_brain.data.sampler.DistributedEvaluationSamplerWrapper(sampler, num_replicas=None, rank=None)
Bases: torch.utils.data.sampler.Sampler

Wraps a sampler for use in distributed evaluation. The standard distributed samplers in PyTorch and PyTorch Lightning give every rank the same number of samples, which can require dropping samples or padding with duplicates. This wrapper instead distributes all samples across ranks without dropping any, which guarantees that evaluation covers the complete dataset.
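One way to distribute samples without dropping any is a strided split, where rank r keeps every num_replicas-th item starting at position r. The sketch below is hypothetical: EvalShardSketch is an illustrative name, the strided split is an assumption (the library may partition indices differently, for example into contiguous blocks), and it is a plain Python class implementing __iter__ and __len__ rather than a torch.utils.data.Sampler subclass so that it runs without PyTorch installed.

```python
from typing import Iterator


class EvalShardSketch:
    """Hypothetical stand-in for an evaluation sampler wrapper.

    Shards the items of a wrapped sampler across ranks using a strided split
    (an assumption for illustration), without dropping or padding any item.
    """

    def __init__(self, sampler, num_replicas: int, rank: int) -> None:
        if num_replicas < 1:
            raise ValueError("num_replicas must be >= 1")
        if not 0 <= rank < num_replicas:
            raise ValueError("rank must be in [0, num_replicas)")
        self.sampler = sampler  # any iterable that also supports len()
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self) -> Iterator:
        # Rank r keeps positions r, r + R, r + 2R, ... of the wrapped sampler's order
        for position, item in enumerate(self.sampler):
            if position % self.num_replicas == self.rank:
                yield item

    def __len__(self) -> int:
        # ceil((N - r) / R): later ranks may receive one item fewer than earlier ones
        n = len(self.sampler)
        return (n - self.rank + self.num_replicas - 1) // self.num_replicas


shards = [list(EvalShardSketch(range(10), num_replicas=3, rank=r)) for r in range(3)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Every index lands on exactly one rank and shard sizes differ by at most one, which is why ranks can end up running different numbers of steps.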
Warning
Because ranks can receive different numbers of samples, some ranks may perform more evaluation steps than others. This wrapper therefore assumes that ranks do not communicate during evaluation, only at its beginning or end (a collective operation issued on every step, for example, could hang), so it is suitable only for standard evaluation loops.
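The step imbalance can be made concrete with a little arithmetic. Assuming the same hypothetical strided split (rank r receives ceil((N - r) / R) of N samples across R ranks), the illustrative helper steps_per_rank counts the batches each rank processes: with 10 samples, 4 ranks, and a batch size of 2, ranks 0 and 1 run two steps while ranks 2 and 3 run one, so a per-step collective would leave ranks 0 and 1 waiting on ranks that never reach step two. The second hypothetical helper, combine_mean, sketches aggregation done once at the end: merge per-rank sums and counts rather than averaging per-rank means, which would give samples on smaller shards too much weight.

```python
from typing import List, Sequence, Tuple


def steps_per_rank(num_samples: int, num_replicas: int, batch_size: int) -> List[int]:
    """Batches each rank runs under a hypothetical strided split (illustrative helper)."""
    steps = []
    for rank in range(num_replicas):
        # Rank r holds items r, r + R, r + 2R, ... -> ceil((N - r) / R) items
        shard_size = (num_samples - rank + num_replicas - 1) // num_replicas
        # Integer ceil division: a partial final batch still costs one step
        steps.append((shard_size + batch_size - 1) // batch_size)
    return steps


def combine_mean(per_rank: Sequence[Tuple[float, int]]) -> float:
    """Merge (sum_of_values, num_samples) pairs gathered once at the end of evaluation."""
    total = sum(s for s, _ in per_rank)
    count = sum(n for _, n in per_rank)
    return total / count


print(steps_per_rank(num_samples=10, num_replicas=4, batch_size=2))  # [2, 2, 1, 1]
print(combine_mean([(6.0, 3), (4.0, 1)]))  # 2.5, whereas the mean of per-rank means is 3.0
```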
Parameters:
sampler (torch.utils.data.Sampler) – The original sampler to wrap.
num_replicas (int, optional) – Number of processes participating in the distributed evaluation.
rank (int, optional) – Rank of the current process.
Example
>>> from temporaldata import Interval  # assumed import location for Interval
>>> from torch_brain.data.sampler import SequentialFixedWindowSampler, DistributedEvaluationSamplerWrapper
>>> sampling_intervals = {
...     "session_1": Interval(0, 100),
...     "session_2": Interval(0, 100),
... }
>>> sampler = SequentialFixedWindowSampler(sampling_intervals=sampling_intervals, window_length=10)
>>> dist_sampler = DistributedEvaluationSamplerWrapper(sampler)