DistributedEvaluationSamplerWrapper

class torch_brain.data.sampler.DistributedEvaluationSamplerWrapper(sampler, num_replicas=None, rank=None)

Bases: torch.utils.data.sampler.Sampler

Wraps a sampler for use in distributed evaluation. The standard distributed samplers in PyTorch and PyTorch Lightning give every rank the same number of samples. When the dataset size is not divisible by the number of ranks, they achieve this by dropping samples or by padding with duplicated ones. This wrapper instead distributes all samples across ranks without dropping any, which guarantees that evaluation covers the complete dataset. As a consequence, different ranks may receive slightly different numbers of samples.
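The difference between the two strategies can be illustrated with a minimal pure-Python sketch. It assumes a strided (round-robin) split, which is not necessarily how this wrapper partitions indices internally. split_for_rank and split_with_padding are hypothetical helpers written for illustration, not torch_brain or PyTorch APIs.

```python
from typing import List, Sequence


def split_for_rank(indices: Sequence[int], num_replicas: int, rank: int) -> List[int]:
    # Hypothetical strided split: rank r takes positions r, r + R, r + 2R, ...
    # Nothing is dropped or repeated, so rank lengths differ by at most one.
    return list(indices[rank::num_replicas])


def split_with_padding(indices: Sequence[int], num_replicas: int, rank: int) -> List[int]:
    # Hypothetical pad-to-equal-length split, similar in spirit to
    # torch.utils.data.DistributedSampler with drop_last=False: indices are
    # repeated from the start until the total is a multiple of num_replicas.
    indices = list(indices)
    total = -(-len(indices) // num_replicas) * num_replicas  # round up to a multiple
    repeats = total // max(len(indices), 1) + 1
    padded = (indices * repeats)[:total]
    return padded[rank::num_replicas]


indices = list(range(10))
num_replicas = 4

no_drop = [split_for_rank(indices, num_replicas, r) for r in range(num_replicas)]
padded = [split_with_padding(indices, num_replicas, r) for r in range(num_replicas)]

print([len(s) for s in no_drop])  # [3, 3, 2, 2]
print([len(s) for s in padded])   # [3, 3, 3, 3]
print(sorted(sum(no_drop, [])) == indices)  # True: every sample exactly once
```

With padding, samples 0 and 1 are evaluated twice, which would bias metrics computed over the whole dataset. The strided split avoids this at the cost of unequal rank lengths.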

Warning

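This wrapper assumes that ranks communicate only at the beginning or end of evaluation, so it is suitable only for standard evaluation loops. Because ranks can receive different numbers of samples, some ranks may perform more steps than others. Any collective operation issued once per step (for example, a per-batch all_reduce used to synchronize a metric) can therefore block on a rank that has already finished.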
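To see how uneven step counts arise, the following sketch counts the batches each rank would iterate over. It assumes a strided split (rank r gets indices r, r + R, r + 2R, ...) and a DataLoader that keeps the last partial batch (drop_last=False); steps_per_rank is a hypothetical helper, not part of torch_brain.

```python
import math
from typing import List


def steps_per_rank(num_samples: int, num_replicas: int, batch_size: int) -> List[int]:
    # Hypothetical helper: number of batches per rank under a strided split,
    # keeping the final partial batch on each rank.
    steps = []
    for rank in range(num_replicas):
        n_rank = len(range(rank, num_samples, num_replicas))  # samples on this rank
        steps.append(math.ceil(n_rank / batch_size))
    return steps


print(steps_per_rank(num_samples=101, num_replicas=2, batch_size=10))  # [6, 5]
```

Here rank 0 runs one more step than rank 1, so a per-step all_reduce on rank 0 would wait for a partner that has already left the loop. Accumulating outputs locally and combining them once after the loop avoids this.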

Parameters:
  • sampler (torch.utils.data.Sampler) – The original sampler to wrap.

  • num_replicas (int, optional) – Number of processes participating in the distributed evaluation.

  • rank (int, optional) – Rank of the current process.

Example

>>> from temporaldata import Interval
>>> from torch_brain.data.sampler import SequentialFixedWindowSampler, DistributedEvaluationSamplerWrapper

>>> sampling_intervals = {
...     "session_1": Interval(0, 100),
...     "session_2": Interval(0, 100),
... }

>>> sampler = SequentialFixedWindowSampler(sampling_intervals=sampling_intervals, window_length=10)
>>> dist_sampler = DistributedEvaluationSamplerWrapper(sampler)

rank_len()

Returns the number of samples assigned to the current process.
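If indices are assigned in a strided (round-robin) fashion, which is an assumption here rather than documented behavior, the value returned by rank_len() has a closed form, ceil((N - rank) / R) for N samples and R replicas. The sketch below computes it without building the index list and checks it against an explicit split; expected_rank_len is a hypothetical name.

```python
def expected_rank_len(num_samples: int, num_replicas: int, rank: int) -> int:
    # Hypothetical helper: samples owned by `rank` under a strided split.
    # Integer form of ceil((num_samples - rank) / num_replicas); it also
    # evaluates to 0 when num_samples <= rank.
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    return (num_samples - rank + num_replicas - 1) // num_replicas


num_samples, num_replicas = 101, 4
lengths = [expected_rank_len(num_samples, num_replicas, r) for r in range(num_replicas)]
print(lengths)  # [26, 25, 25, 25]

# Every sample is counted exactly once across ranks.
assert sum(lengths) == num_samples
assert lengths == [len(range(r, num_samples, num_replicas)) for r in range(num_replicas)]
```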