"""Distributed fixed-window sampler that co-locates windows for prediction stitching.

Module: ``torch_brain.samplers.distributed_stitching_fixed_window``.
"""

from typing import List, Dict, Tuple

import torch
import torch.distributed as dist

from temporaldata import Interval
from torch_brain.dataset import DatasetIndex


class DistributedStitchingFixedWindowSampler(torch.utils.data.DistributedSampler):
    r"""Distributed sliding-window sampler that co-locates windows for prediction stitching.

    This sampler is designed for distributed evaluation with overlapping windows,
    where predictions from adjacent windows must later be stitched together. It
    assigns all windows from the same contiguous interval to the *same* rank, so
    stitching can be performed locally without any cross-rank communication.

    In addition to the window indices, the sampler exposes a :attr:`sequence_index`
    tensor that maps each window to its parent interval. A downstream stitcher can
    use this to detect when all windows of a sequence have been processed and
    immediately stitch and evaluate that sequence, keeping peak memory low.

    Intervals are assigned to ranks using a greedy load-balancing heuristic
    (largest interval first) so that the number of windows per rank stays as
    equal as possible. Intervals shorter than :obj:`window_length` are skipped.

    Within an interval, windows start on a regular grid
    ``start, start + step, ...``. If that grid stops short of the interval end,
    one extra window aligned to the interval end is appended so the tail is
    still covered.

    .. note::
        :obj:`step` must be ``<= window_length``. Use a smaller step to create
        overlapping windows for smoother stitched predictions.

    Args:
        sampling_intervals: Sampling intervals for each session. Typically obtained
            from :meth:`~torch_brain.dataset.Dataset.get_sampling_intervals`.
        batch_size: Number of windows per batch, used by the stitcher to track
            sequence boundaries within a batch.
        window_length: Duration of each sliding window in seconds.
        step: Stride between consecutive windows in seconds. If ``None`` (default),
            sets :obj:`step` to :obj:`window_length` (non-overlapping windows).
        num_replicas: Total number of processes. If ``None`` (default), resolved
            from :func:`torch.distributed.get_world_size`.
        rank: Rank of the current process. If ``None`` (default), resolved from
            :func:`torch.distributed.get_rank`.
        seed: Seed for the shuffle of interval order on this rank (default ``0``).
            A private generator is used, so the resulting order is reproducible
            and the global torch RNG state is left untouched.

    Raises:
        ValueError: If ``window_length <= 0``, ``step <= 0``,
            ``step > window_length``, or ``rank`` is outside
            ``[0, num_replicas - 1]``.
        RuntimeError: If ``num_replicas`` or ``rank`` must be resolved but
            :mod:`torch.distributed` is not available.

    Attributes:
        sequence_index (torch.Tensor): 1-D ``torch.long`` tensor of length
            ``len(self)`` mapping each window index to its contiguous-interval
            index on this rank. Consecutive windows sharing the same value belong
            to the same interval and will be stitched together.

    Example::

        >>> from temporaldata import Interval
        >>> from torch_brain.samplers import DistributedStitchingFixedWindowSampler
        >>> sampling_intervals = {
        ...     "session_1": Interval(0.0, 100.0),
        ... }
        >>> sampler = DistributedStitchingFixedWindowSampler(
        ...     sampling_intervals=sampling_intervals,
        ...     batch_size=8,
        ...     window_length=2.0,
        ...     step=1.0,
        ...     num_replicas=1,
        ...     rank=0,
        ... )
        >>> len(sampler)
        99
    """

    # Absolute tolerance (in seconds) for floating-point comparisons of window
    # boundaries, so round-off in the start grid neither drops a window nor
    # creates a near-duplicate one.
    _TIME_TOLERANCE = 1e-9

    def __init__(
        self,
        *,
        sampling_intervals: Dict[str, Interval],
        batch_size: int,
        window_length: float,
        step: float | None = None,
        num_replicas: int | None = None,
        rank: int | None = None,
        seed: int = 0,
    ):
        if window_length <= 0:
            raise ValueError("window_length must be greater than 0.")

        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.sampling_intervals = sampling_intervals
        self.batch_size = batch_size
        self.window_length = window_length
        self.step = window_length if step is None else step
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0

        if self.step <= 0:
            raise ValueError("Step must be greater than 0.")
        if self.step > self.window_length:
            raise ValueError("Step must be less than or equal to window_length.")

        # Generate indices for this rank
        self.indices, self.sequence_index = self._generate_indices()
        self.num_samples = len(self.indices)

    def _interval_windows(self, start: float, end: float) -> List[Tuple[float, float]]:
        """Return the ``(window_start, window_end)`` bounds tiling ``[start, end]``.

        Windows start on the grid ``start, start + step, ...`` up to
        ``end - window_length`` (inclusive, within tolerance). If the last grid
        window ends before ``end``, one extra window ``(end - window_length, end)``
        is appended so the interval tail is covered. Callers are expected to pass
        only intervals with ``end - start >= window_length``.
        """
        tol = self._TIME_TOLERANCE
        grid = torch.arange(
            start,
            end - self.window_length + tol,
            self.step,
            dtype=torch.float64,
        ).tolist()
        windows = [(t, t + self.window_length) for t in grid]

        last_start = grid[-1] if grid else start
        # The tolerance prevents appending a window that differs from the last
        # grid window only by floating-point round-off.
        if last_start + self.window_length < end - tol:
            windows.append((end - self.window_length, end))
        return windows

    def _generate_indices(self) -> Tuple[List[DatasetIndex], torch.Tensor]:
        """Build window indices for this rank using a greedy load-balancing assignment.

        Intervals are sorted by their exact window count (largest first) and each
        is assigned to the rank with the fewest windows so far (lowest rank wins
        ties). All windows within an interval are kept together on the assigned
        rank to enable local stitching without cross-rank communication. The
        assignment depends only on ``sampling_intervals``, ``window_length``,
        ``step`` and ``num_replicas``, so ranks agree as long as they are given
        identical inputs (presumably the case in a DDP launch — verify at call
        sites).

        Returns:
            Tuple of ``(indices, sequence_index)`` where ``indices`` is the list of
            :class:`~torch_brain.dataset.DatasetIndex` objects for this rank and
            ``sequence_index`` is a ``torch.long`` tensor mapping each window to
            its parent interval index (``0 .. n_intervals_on_rank - 1``).
        """
        # Count windows per usable interval with the same routine that later
        # generates them, so balancing uses the true (not approximate) sizes.
        candidates = []
        for session_name, intervals in self.sampling_intervals.items():
            for start, end in zip(intervals.start, intervals.end):
                if end - start >= self.window_length:
                    num_windows = len(self._interval_windows(start, end))
                    candidates.append((session_name, start, end, num_windows))

        # Largest interval first. list.sort is stable (also with reverse=True),
        # so tie order is the enumeration order and identical on every rank.
        candidates.sort(key=lambda candidate: candidate[3], reverse=True)

        rank_sizes = [0] * self.num_replicas
        local_sequences = []
        for session_name, start, end, num_windows in candidates:
            target_rank = min(range(self.num_replicas), key=rank_sizes.__getitem__)
            rank_sizes[target_rank] += num_windows
            if target_rank == self.rank:
                # Only materialize DatasetIndex objects for this rank's intervals.
                local_sequences.append(
                    [
                        DatasetIndex(session_name, win_start, win_end)
                        for win_start, win_end in self._interval_windows(start, end)
                    ]
                )

        # Shuffle interval order on this rank with a private, seeded generator:
        # reproducible across runs and does not consume the global torch RNG.
        generator = torch.Generator().manual_seed(self.seed)
        permutation = torch.randperm(len(local_sequences), generator=generator).tolist()
        local_sequences = [local_sequences[i] for i in permutation]

        indices = [index for sequence in local_sequences for index in sequence]
        # Explicit dtype: torch.tensor([]) would otherwise be float32 on a rank
        # that received no intervals.
        sequence_index = torch.tensor(
            [seq_id for seq_id, sequence in enumerate(local_sequences) for _ in sequence],
            dtype=torch.long,
        )
        return indices, sequence_index

    def __iter__(self):
        r"""Yields :class:`~torch_brain.dataset.DatasetIndex` objects assigned to this rank."""
        return iter(self.indices)

    def __len__(self) -> int:
        r"""Returns the number of windows assigned to this rank."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Store the current epoch number for API compatibility.

        This sampler is deterministic and does not re-shuffle on each epoch, so
        calling this method has no effect on the produced indices. It is provided
        for compatibility with training loops that call
        ``sampler.set_epoch(epoch)`` unconditionally.

        Args:
            epoch: The epoch number to record.
        """
        self.epoch = epoch