DistributedEvaluationSamplerWrapper

class torch_brain.samplers.DistributedEvaluationSamplerWrapper(sampler, num_replicas=None, rank=None)

Bases: torch.utils.data.sampler.Sampler

Wraps any sampler for use in distributed evaluation without dropping samples.

Unlike the standard distributed samplers from PyTorch and PyTorch Lightning, which ensure equal steps per rank by potentially dropping samples, this wrapper preserves all samples by interleaving them across ranks. This guarantees that evaluation metrics are computed over the complete dataset.
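
The interleaving can be pictured with a minimal sketch. It assumes a round-robin assignment (rank r takes items r, r + num_replicas, and so on), which is one common way to interleave; the wrapper's actual ordering is not spelled out here, and `shard_round_robin` is a hypothetical helper, not part of torch_brain.

```python
from itertools import islice


def shard_round_robin(indices, num_replicas, rank):
    """Return the items of `indices` assigned to `rank` (hypothetical helper)."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must satisfy 0 <= rank < num_replicas")
    # Assumed policy: take every num_replicas-th item, starting at offset `rank`.
    return list(islice(indices, rank, None, num_replicas))


# 10 samples over 4 ranks: nothing is dropped and nothing is padded,
# so the shards may differ in length.
shards = [shard_round_robin(range(10), 4, r) for r in range(4)]
```

Under this assumption the union of all shards is exactly the original index set, which is the property the wrapper is described as guaranteeing.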

Warning

Because this wrapper does not pad ranks to equal length, some ranks may perform more steps than others. There must be no inter-rank communication (e.g. allreduce) during the evaluation loop; only barrier-style synchronization at the start and end is safe.
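
One way to respect this constraint is to accumulate metric statistics locally on each rank and combine them once, after every rank has finished iterating. The sketch below simulates the ranks in a single process with plain Python; the round-robin split and the names `local_stats` and `combine` are illustrative assumptions, not torch_brain APIs.

```python
def local_stats(values, num_replicas, rank):
    """Per-rank loop: accumulate (sum, count) with no inter-rank communication."""
    total, count = 0.0, 0
    # Assumed round-robin split; ranks may run a different number of steps.
    for value in values[rank::num_replicas]:
        total += value
        count += 1
    return total, count


def combine(stats):
    """Single reduction performed after all ranks have finished their loops."""
    total = sum(s for s, _ in stats)
    count = sum(c for _, c in stats)
    return total / count


values = [float(v) for v in range(10)]  # stand-in for per-sample metric values
stats = [local_stats(values, 4, rank) for rank in range(4)]
mean = combine(stats)
```

Combining sums and counts, rather than averaging per-rank means, keeps the result exact when ranks hold different numbers of samples.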

Rank and world-size are intentionally not resolved in __init__; call set_params() once the distributed environment is initialised before iterating.

Parameters:
  • sampler (Sampler) – The base sampler whose indices will be distributed across ranks.

  • num_replicas (Optional[int]) – Total number of processes participating in evaluation. If None, must be set later via set_params().

  • rank (Optional[int]) – Rank of the current process. If None, must be set later via set_params().

Example:

>>> from temporaldata import Interval
>>> from torch_brain.samplers import SequentialFixedWindowSampler, DistributedEvaluationSamplerWrapper

>>> sampling_intervals = {
...     "session_1": Interval(0.0, 100.0),
...     "session_2": Interval(0.0, 100.0),
... }
>>> sampler = SequentialFixedWindowSampler(
...     sampling_intervals=sampling_intervals,
...     window_length=10.0,
... )
>>> dist_sampler = DistributedEvaluationSamplerWrapper(sampler)
>>> dist_sampler.set_params(num_replicas=4, rank=0)

set_params(num_replicas, rank)

Configure distributed parameters after the process group has been initialised.

Parameters:
  • num_replicas (int) – Total number of processes participating in evaluation.

  • rank (int) – Rank of the current process within the process group.

Return type:

None
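
A minimal sketch of how the two arguments might be obtained, assuming the process group is managed through `torch.distributed`. `resolve_dist_params` and `_StubSampler` are hypothetical names used only for illustration; the stub stands in for a real DistributedEvaluationSamplerWrapper so the snippet runs without a dataset.

```python
def resolve_dist_params():
    """Return (num_replicas, rank), falling back to (1, 0) outside distributed runs."""
    try:
        import torch.distributed as dist
    except ImportError:
        return 1, 0  # PyTorch not installed: treat as a single process
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()
    return 1, 0  # process group not initialised


class _StubSampler:
    """Stand-in exposing the same set_params(num_replicas, rank) signature."""

    def set_params(self, num_replicas, rank):
        self.num_replicas, self.rank = num_replicas, rank


dist_sampler = _StubSampler()  # in practice, a DistributedEvaluationSamplerWrapper
num_replicas, rank = resolve_dist_params()
dist_sampler.set_params(num_replicas=num_replicas, rank=rank)
```

The single-process fallback is a design choice of this sketch, so the same evaluation script can run with or without a distributed launcher.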

rank_len()

Returns the number of samples assigned to the current process.
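
As a rough illustration of how this count can differ between ranks, the sketch below assumes a round-robin interleaving in which rank r takes every num_replicas-th sample starting at offset r. That policy is an assumption, and `expected_rank_len` is a hypothetical helper rather than the library's implementation.

```python
def expected_rank_len(total, num_replicas, rank):
    """Number of samples `rank` receives under the assumed round-robin split."""
    # Counts the indices rank, rank + num_replicas, ... that are below `total`.
    return len(range(rank, total, num_replicas))


# 10 samples over 4 ranks: lower ranks receive one extra sample.
lengths = [expected_rank_len(10, 4, r) for r in range(4)]
```

The per-rank lengths sum to the total number of samples, consistent with no samples being dropped or duplicated.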