DistributedStitchingFixedWindowSampler
- class torch_brain.samplers.DistributedStitchingFixedWindowSampler(*, sampling_intervals, batch_size, window_length, step=None, num_replicas=None, rank=None)
Bases: torch.utils.data.distributed.DistributedSampler

Distributed sliding-window sampler that co-locates windows for prediction stitching.
This sampler is designed for distributed evaluation with overlapping windows, where predictions from adjacent windows must later be stitched together. It assigns all windows from the same contiguous interval to the same rank, so stitching can be performed locally without any cross-rank communication.
In addition to the window indices, the sampler exposes a sequence_index tensor that maps each window to its parent interval. A downstream stitcher can use this to detect when all windows of a sequence have been processed and immediately stitch and evaluate that sequence, keeping peak memory low.

Intervals are assigned to ranks using a greedy load-balancing heuristic (largest interval first) so that the number of windows per rank stays as equal as possible.
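The largest-first heuristic is essentially the classic longest-processing-time (LPT) scheduling rule: sort the intervals by size, then give each one, whole, to the rank that currently has the fewest windows. The sketch below illustrates the idea under stated assumptions. The helper assign_intervals_to_ranks and its input (a dict mapping an interval key to that interval's window count) are hypothetical and not part of torch_brain, and the library's exact ordering and tie-breaking may differ.

```python
import heapq
from typing import Dict, List


def assign_intervals_to_ranks(
    window_counts: Dict[str, int], num_replicas: int
) -> List[List[str]]:
    """Hypothetical helper: greedy largest-first assignment of whole intervals to ranks."""
    # Min-heap of (windows assigned so far, rank); ties go to the lowest rank.
    loads = [(0, rank) for rank in range(num_replicas)]
    heapq.heapify(loads)
    assignment: List[List[str]] = [[] for _ in range(num_replicas)]
    # Largest interval first; sorting by key as well keeps the result deterministic.
    for key, count in sorted(window_counts.items(), key=lambda kv: (-kv[1], kv[0])):
        load, rank = heapq.heappop(loads)  # currently least-loaded rank
        assignment[rank].append(key)  # every window of this interval stays on one rank
        heapq.heappush(loads, (load + count, rank))
    return assignment
```

For example, with window counts {"a": 50, "b": 30, "c": 25, "d": 10} and two ranks, rank 0 receives ["a", "d"] (60 windows) and rank 1 receives ["b", "c"] (55 windows). Because intervals are never split, the balance is approximate rather than exact, which is the trade-off that lets stitching stay rank-local.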
Note
step must be <= window_length. Use a smaller step to create overlapping windows for smoother stitched predictions.
- Parameters:
  - sampling_intervals (Dict[str, Interval]) – Sampling intervals for each session. Typically obtained from get_sampling_intervals().
  - batch_size (int) – Number of windows per batch, used by the stitcher to track sequence boundaries within a batch.
  - window_length (float) – Duration of each sliding window in seconds.
  - step (Optional[float]) – Stride between consecutive windows in seconds. If None (default), step is set to window_length (non-overlapping windows).
  - num_replicas (Optional[int]) – Total number of processes. If None (default), resolved from torch.distributed.get_world_size().
  - rank (Optional[int]) – Rank of the current process. If None (default), resolved from torch.distributed.get_rank().
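To make the relationship between window_length, step, and the number of windows concrete, here is a minimal sketch of how window start times could be laid out over a single interval. It assumes every window must fit entirely inside the interval and, as an assumption about the tail, appends one extra window aligned to the interval end when step does not divide the remaining length evenly, so that every timestamp is covered for stitching. The function window_starts is hypothetical; torch_brain's own handling of the tail may differ.

```python
import numpy as np


def window_starts(start: float, end: float, window_length: float, step: float) -> np.ndarray:
    """Hypothetical helper: start times of sliding windows covering [start, end]."""
    if step > window_length:
        # Mirrors the documented constraint: a larger step would leave gaps.
        raise ValueError("step must be <= window_length")
    if end - start < window_length:
        return np.empty(0)  # interval too short for even one full window
    last = end - window_length  # latest start that still fits inside the interval
    starts = np.arange(start, last, step)  # regular grid, strictly before `last`
    # Assumption: a final end-aligned window covers any remainder of the interval.
    return np.append(starts, last)
```

With the values from the class example, window_starts(0.0, 100.0, 2.0, 1.0) returns 99 start times (0.0 through 98.0). When the stride does not divide evenly, for instance window_starts(0.0, 10.0, 4.0, 3.0), the result is [0.0, 3.0, 6.0], and the last window overlaps its neighbour more than the others do.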
- sequence_index
1-D integer tensor of length len(self) mapping each window index to its contiguous-interval index on this rank. Consecutive windows sharing the same value belong to the same interval and will be stitched together.
- Type:
Example:

>>> import numpy as np
>>> from temporaldata import Interval
>>> from torch_brain.samplers import DistributedStitchingFixedWindowSampler
>>> sampling_intervals = {
...     "session_1": Interval(0.0, 100.0),
... }
>>> sampler = DistributedStitchingFixedWindowSampler(
...     sampling_intervals=sampling_intervals,
...     batch_size=8,
...     window_length=2.0,
...     step=1.0,
...     num_replicas=1,
...     rank=0,
... )
>>> len(sampler)
99
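Because windows of the same interval are consecutive, a consumer of sequence_index can tell that a sequence is finished as soon as the window following the current batch belongs to a different sequence. The sketch below shows that bookkeeping; completed_sequences is a hypothetical helper, not torch_brain's stitcher, and it only tracks window positions where a real stitcher would buffer model predictions. It accepts a 1-D tensor or any indexable sequence of ints.

```python
from typing import Dict, Iterator, List, Optional, Sequence, Tuple


def completed_sequences(
    sequence_index: Sequence[int], batch_size: int
) -> Iterator[Tuple[int, List[int]]]:
    """Hypothetical helper: yield (sequence id, window positions) once a sequence is complete."""
    n = len(sequence_index)
    pending: Dict[int, List[int]] = {}  # sequence id -> window positions seen so far
    for batch_start in range(0, n, batch_size):
        batch = range(batch_start, min(batch_start + batch_size, n))
        for i in batch:
            # A real stitcher would store the prediction for window i here.
            pending.setdefault(int(sequence_index[i]), []).append(i)
        # Only the sequence of the next unseen window (if any) can still receive windows.
        next_seq: Optional[int] = int(sequence_index[batch.stop]) if batch.stop < n else None
        for seq_id in [s for s in pending if s != next_seq]:
            # All windows of seq_id are in: stitch and evaluate, then free the buffer.
            yield seq_id, pending.pop(seq_id)
```

For sequence_index = [0, 0, 0, 1, 1, 2, 2, 2, 2] and batch_size=4, sequence 0 is released after the first batch, sequence 1 after the second, and sequence 2 after the last, so at most one partially seen sequence (plus the current batch) is ever held in memory.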
- set_epoch(epoch)
Store the current epoch number for API compatibility.
This sampler is deterministic and does not re-shuffle on each epoch, so calling this method has no effect on the produced indices. It is provided for compatibility with training loops that call sampler.set_epoch(epoch) unconditionally.