DistributedStitchingFixedWindowSampler

class torch_brain.samplers.DistributedStitchingFixedWindowSampler(*, sampling_intervals, batch_size, window_length, step=None, num_replicas=None, rank=None)

Bases: torch.utils.data.distributed.DistributedSampler

Distributed sliding-window sampler that co-locates windows for prediction stitching.

This sampler is designed for distributed evaluation with overlapping windows, where predictions from adjacent windows must later be stitched together. It assigns all windows from the same contiguous interval to the same rank, so stitching can be performed locally without any cross-rank communication.

In addition to the window indices, the sampler exposes a sequence_index tensor that maps each window to its parent interval. A downstream stitcher can use this to detect when all windows of a sequence have been processed and immediately stitch and evaluate that sequence, keeping peak memory low.
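The streaming pattern described above can be sketched as follows. This is a minimal illustration, not torch_brain's own stitcher: `stream_stitch` and `on_sequence_done` are hypothetical names, each window's output is assumed to be a pair of 1-D tensors (timestamps, predictions) produced in the sampler's window order, and overlapping predictions are merged by simple averaging at identical timestamps.

```python
from typing import Callable, Iterable, List, Tuple

import torch


def stream_stitch(
    sequence_index: torch.Tensor,
    window_outputs: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    on_sequence_done: Callable[[int, torch.Tensor, torch.Tensor], None],
) -> None:
    """Hypothetical streaming stitcher driven by sequence_index.

    window_outputs must yield (timestamps, predictions) pairs of 1-D tensors
    in the same order as the sampler's windows.
    """
    seq_ids = sequence_index.tolist()
    buf_t: List[torch.Tensor] = []
    buf_y: List[torch.Tensor] = []
    for i, (t, y) in enumerate(window_outputs):
        buf_t.append(t)
        buf_y.append(y)
        # A sequence is complete when the next window belongs to another
        # interval, or when this is the last window on this rank.
        last_of_sequence = i + 1 == len(seq_ids) or seq_ids[i + 1] != seq_ids[i]
        if not last_of_sequence:
            continue
        ts, ys = torch.cat(buf_t), torch.cat(buf_y)
        # Average all predictions that share a timestamp (the overlap regions).
        uniq_t, inverse = torch.unique(ts, return_inverse=True)
        sums = torch.zeros(len(uniq_t), dtype=ys.dtype).index_add_(0, inverse, ys)
        counts = torch.bincount(inverse, minlength=len(uniq_t)).to(ys.dtype)
        on_sequence_done(seq_ids[i], uniq_t, sums / counts)
        # Drop the buffers so memory stays bounded by a single sequence.
        buf_t.clear()
        buf_y.clear()
```

Only one sequence is ever buffered at a time, which is what keeps peak memory low when all windows of an interval arrive consecutively on the same rank.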

Intervals are assigned to ranks using a greedy load-balancing heuristic (largest interval first) so that the number of windows per rank stays as equal as possible.
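A largest-first greedy assignment of this kind can be sketched with a min-heap of per-rank window counts. This illustrates the heuristic and is not the library's code; `assign_intervals` is a hypothetical helper, and its input is assumed to be a precomputed number of windows for each contiguous interval, keyed by a hypothetical interval id.

```python
import heapq
from typing import Dict, List, Tuple


def assign_intervals(num_windows: Dict[str, int], num_replicas: int) -> List[List[str]]:
    """Greedily assign whole intervals to ranks, largest interval first."""
    # Min-heap of (windows assigned so far, rank); ties go to the lower rank.
    heap: List[Tuple[int, int]] = [(0, rank) for rank in range(num_replicas)]
    heapq.heapify(heap)
    assignment: List[List[str]] = [[] for _ in range(num_replicas)]
    # Visit intervals from most to fewest windows.
    for key in sorted(num_windows, key=lambda k: num_windows[k], reverse=True):
        load, rank = heapq.heappop(heap)  # currently least-loaded rank
        assignment[rank].append(key)  # the whole interval stays on one rank
        heapq.heappush(heap, (load + num_windows[key], rank))
    return assignment


print(assign_intervals({"a": 10, "b": 7, "c": 5, "d": 4}, num_replicas=2))
```

Because whole intervals are the unit of assignment, this greedy scheme can still leave ranks with different window counts, though the gap between the most and least loaded rank is at most the window count of the largest interval.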

Note

step must be <= window_length. Use a smaller step to create overlapping windows for smoother stitched predictions.
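How window_length and step lay out windows over one contiguous interval can be sketched as below. This assumes windows start at the beginning of the interval, that step=None falls back to window_length (non-overlapping windows), and that any tail shorter than a full window is left uncovered; the library may treat the default and the tail differently. `window_starts` is a hypothetical helper. For the interval [0, 100] with window_length=2.0 and step=1.0 it yields 99 windows, consistent with len(sampler) in the example.

```python
import numpy as np


def window_starts(start, end, window_length, step=None):
    """Start times of fixed-length windows within [start, end] (sketch).

    Assumptions: step defaults to window_length, and a tail shorter than
    window_length is not covered.
    """
    if step is None:
        step = window_length
    if not 0 < step <= window_length:
        raise ValueError("step must satisfy 0 < step <= window_length")
    # Small tolerance guards against float round-off, e.g. (1.0 - 0.3) / 0.1.
    n = int(np.floor((end - start - window_length) / step + 1e-9)) + 1
    return start + step * np.arange(max(n, 0))
```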

Attributes:

sequence_index

1-D integer tensor of length len(self) mapping each window index to its contiguous-interval index on this rank. Consecutive windows sharing the same value belong to the same interval and will be stitched together.

Type:

torch.Tensor
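Because the windows of one interval occupy a consecutive run in sequence_index, the window-index range of each interval can be recovered by run-length encoding. A minimal sketch; `interval_slices` is a hypothetical helper, not part of torch_brain, and it assumes each interval id appears as a single consecutive run, as described above.

```python
from typing import Dict, Tuple

import torch


def interval_slices(sequence_index: torch.Tensor) -> Dict[int, Tuple[int, int]]:
    """Map each interval id to its half-open window-index range [first, stop)."""
    ids, counts = torch.unique_consecutive(sequence_index, return_counts=True)
    stops = torch.cumsum(counts, dim=0)
    firsts = stops - counts
    return {
        int(i): (int(f), int(s))
        for i, f, s in zip(ids.tolist(), firsts.tolist(), stops.tolist())
    }


print(interval_slices(torch.tensor([0, 0, 0, 1, 1, 2])))
```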

Example:

>>> from temporaldata import Interval
>>> from torch_brain.samplers import DistributedStitchingFixedWindowSampler

>>> sampling_intervals = {
...     "session_1": Interval(0.0, 100.0),
... }
>>> sampler = DistributedStitchingFixedWindowSampler(
...     sampling_intervals=sampling_intervals,
...     batch_size=8,
...     window_length=2.0,
...     step=1.0,
...     num_replicas=1,
...     rank=0,
... )
>>> len(sampler)
99

set_epoch(epoch)

Store the current epoch number for API compatibility.

This sampler is deterministic and does not re-shuffle on each epoch, so calling this method has no effect on the produced indices. It is provided for compatibility with training loops that call sampler.set_epoch(epoch) unconditionally.

Parameters:

epoch (int) – The epoch number to record.

Return type:

None
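The compatibility pattern can be illustrated with a training loop that calls set_epoch at the start of every epoch. `DeterministicSampler` and `run_epochs` below are hypothetical stand-ins that mimic the documented behavior (the epoch is recorded, but the order never changes); they are not torch_brain classes.

```python
from typing import Iterator, List

from torch.utils.data import Sampler


class DeterministicSampler(Sampler[int]):
    """Hypothetical stand-in: records the epoch but never reshuffles."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch  # stored only; has no effect on the order

    def __iter__(self) -> Iterator[int]:
        return iter(range(self.n))

    def __len__(self) -> int:
        return self.n


def run_epochs(sampler, num_epochs: int) -> List[List[int]]:
    """Call set_epoch unconditionally, then record each epoch's index order."""
    orders = []
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)
        orders.append(list(sampler))
    return orders
```

Every epoch produces the same order, so evaluation results stay reproducible while the loop keeps the same shape it would have with a shuffling DistributedSampler.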