Source code for torch_brain.samplers.distributed_evaluation_sampler
import logging
import torch
[docs]
class DistributedEvaluationSamplerWrapper(torch.utils.data.Sampler):
r"""Wraps any sampler for use in distributed evaluation without dropping samples.
Unlike the standard distributed samplers from PyTorch and PyTorch Lightning, which
ensure equal steps per rank by potentially dropping samples, this wrapper preserves
*all* samples by interleaving them across ranks. This guarantees that evaluation
metrics are computed over the complete dataset.
.. warning::
Because this wrapper does not pad to equal length, some ranks will perform more
steps than others. There must be no inter-rank communication (e.g. allreduce)
during the evaluation loop — only barrier-style synchronization at the start
and end is safe.
Rank and world-size are intentionally **not** resolved in ``__init__``; call
:meth:`set_params` once the distributed environment is initialised before iterating.
Args:
sampler: The base sampler whose indices will be
distributed across ranks.
num_replicas: Total number of processes participating in
evaluation. If ``None``, must be set later via :meth:`set_params`.
rank: Rank of the current process. If ``None``, must be set
later via :meth:`set_params`.
Example::
>>> import numpy as np
>>> from temporaldata import Interval
>>> from torch_brain.samplers import SequentialFixedWindowSampler, DistributedEvaluationSamplerWrapper
>>> sampling_intervals = {
... "session_1": Interval(0.0, 100.0),
... "session_2": Interval(0.0, 100.0),
... }
>>> sampler = SequentialFixedWindowSampler(
... sampling_intervals=sampling_intervals,
... window_length=10.0,
... )
>>> dist_sampler = DistributedEvaluationSamplerWrapper(sampler)
>>> dist_sampler.set_params(num_replicas=4, rank=0)
"""
def __init__(
self,
sampler: torch.utils.data.Sampler,
num_replicas: int | None = None,
rank: int | None = None,
):
self.sampler = sampler
self.num_replicas = num_replicas
self.rank = rank
[docs]
def set_params(self, num_replicas: int, rank: int) -> None:
"""Configure distributed parameters after the process group has been initialised.
Args:
num_replicas: Total number of processes participating in evaluation.
rank: Rank of the current process within the process group.
"""
logging.info(
f"Setting distributed sampler params: "
f"num_replicas={num_replicas}, rank={rank}"
)
self.num_replicas = num_replicas
self.rank = rank
def _check_params(self):
return (self.num_replicas is not None) and (self.rank is not None)
[docs]
def rank_len(self):
r"""Returns the number of samples assigned to the current process."""
if self.num_replicas is None or self.rank is None:
raise RuntimeError(
"num_replicas and rank must be set before calling rank_len(). "
"Call set_params() first."
)
total_len = len(self.sampler)
evenly_split = total_len // self.num_replicas
extra = int(self.rank < (total_len % self.num_replicas))
return evenly_split + extra
def __len__(self):
r"""Returns the number of samples assigned to the current process if
the rank and num_replicas are set. Otherwise, returns the total number
of samples in the original sampler.
"""
if not self._check_params():
return len(self.sampler)
else:
return self.rank_len()
def __iter__(self):
r"""Yields the subset of indices assigned to :attr:`rank` via strided interleaving.
Raises:
AssertionError: If :meth:`set_params` has not been called yet.
"""
assert (
self._check_params()
), "Rank and num_replicas must be set before using the distributed sampler."
indices = list(self.sampler)
indices = indices[self.rank : len(indices) : self.num_replicas]
return iter(indices)