RandomFixedWindowSampler

class torch_brain.samplers.RandomFixedWindowSampler(*, sampling_intervals, window_length, generator=None, drop_short=True)

Bases: torch.utils.data.sampler.Sampler

Samples fixed-length windows randomly from a collection of time intervals.

Given the sampling_intervals dictionary mapping session IDs to temporaldata.Interval objects, this sampler produces DatasetIndex objects for indexing a Dataset. Each call to __iter__() applies a fresh random temporal jitter and re-shuffles the windows, so every epoch explores slightly different positions within each interval.
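A rough picture of one epoch, written as a pure-Python sketch rather than torch_brain's actual code. It assumes that the jitter is a single uniform offset in [0, window_length) drawn per interval, that windows are tiled from that offset, and that one extra window is placed at an edge when the unused time at both ends adds up to a full window (this keeps the per-epoch count equal to the formula that follows). Window and sample_epoch are hypothetical names, intervals are plain (start, end) tuples instead of temporaldata.Interval objects, and random.Random stands in for torch.Generator.

```python
import math
import random
from typing import Dict, List, NamedTuple, Tuple


class Window(NamedTuple):
    """Hypothetical stand-in for torch_brain's DatasetIndex."""
    session_id: str
    start: float
    end: float


def sample_epoch(
    intervals: Dict[str, List[Tuple[float, float]]],
    window_length: float,
    rng: random.Random,
) -> List[Window]:
    """Sketch of one epoch: jitter each interval, tile it, then shuffle."""
    windows: List[Window] = []
    for session_id, spans in intervals.items():
        for start, end in spans:
            length = end - start
            if length < window_length:
                continue  # short intervals are simply dropped in this sketch
            # Assumed jitter: a fresh offset in [0, window_length) per interval.
            offset = rng.random() * window_length
            # Number of whole windows that fit after the offset.
            k = math.floor((length - offset) / window_length)
            for i in range(k):
                t = start + offset + i * window_length
                windows.append(Window(session_id, t, t + window_length))
            # Unused time: `offset` on the left plus the remainder on the right.
            leftover = length - k * window_length
            if leftover >= window_length:
                # Add one edge window on the side with more unused time.
                right = leftover - offset
                if right > offset:
                    windows.append(Window(session_id, end - window_length, end))
                else:
                    windows.append(Window(session_id, start, start + window_length))
    rng.shuffle(windows)  # a new order every epoch
    return windows
```

Calling sample_epoch twice with the same RNG gives different window positions and a different order, while re-seeding the RNG reproduces an epoch exactly; in the real sampler, that reproducibility is what the generator argument is for.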

In one epoch, the number of samples generated from a single contiguous interval of length \(L\) is:

\[N = \left\lfloor\frac{L}{\text{window\_length}}\right\rfloor\]
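A small worked check of the formula. The count is taken per contiguous interval and then, judging by len(sampler) returning 300 in the example, summed over all intervals of all sessions, so leftover time is never pooled across intervals. num_windows_per_epoch is a hypothetical helper operating on plain (start, end) tuples; it is not a torch_brain function.

```python
import math
from typing import Dict, List, Tuple


def num_windows_per_epoch(
    intervals: Dict[str, List[Tuple[float, float]]], window_length: float
) -> int:
    """Apply N = floor(L / window_length) to each interval and sum the results."""
    total = 0
    for spans in intervals.values():
        for start, end in spans:
            # Each contiguous interval is counted on its own.
            total += math.floor((end - start) / window_length)
    return total
```

For instance, a 10.5 s interval with window_length=2.0 gives \(\lfloor 5.25 \rfloor = 5\) windows, and two separate 1.5 s intervals with window_length=1.0 give 2 windows rather than 3. An interval shorter than window_length contributes 0, consistent with it being dropped when drop_short=True.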
Parameters:
  • sampling_intervals (Dict[str, Interval]) – Sampling intervals for each session. Typically obtained from get_sampling_intervals().

  • window_length (float) – Duration of each sampled window in seconds.

  • generator (Optional[Generator]) – Optional RNG used for jitter and shuffling. If None (default), the global PyTorch RNG is used.

  • drop_short (bool) – If True (default), intervals shorter than window_length are skipped and a warning is logged. If False, a ValueError is raised for any interval shorter than window_length.
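The drop_short semantics can be mimicked on plain (start, end) tuples as follows. This is a hypothetical sketch of the documented behaviour, not the library's code: filter_short_intervals is an invented name, the warning text is made up for illustration, and when exactly torch_brain performs this check (at construction or at iteration) is not specified here. An interval exactly as long as window_length is kept, since it is not shorter than a window.

```python
import logging
from typing import Dict, List, Tuple

logger = logging.getLogger(__name__)


def filter_short_intervals(
    intervals: Dict[str, List[Tuple[float, float]]],
    window_length: float,
    drop_short: bool = True,
) -> Dict[str, List[Tuple[float, float]]]:
    """Keep intervals at least one window long; drop or reject the rest."""
    kept: Dict[str, List[Tuple[float, float]]] = {}
    dropped_seconds = 0.0
    for session_id, spans in intervals.items():
        kept[session_id] = []
        for start, end in spans:
            if end - start >= window_length:
                kept[session_id].append((start, end))
            elif drop_short:
                # drop_short=True: skip the interval but remember how much was lost.
                dropped_seconds += end - start
            else:
                # drop_short=False: any short interval is an error.
                raise ValueError(
                    f"Interval ({start}, {end}) in {session_id!r} is shorter "
                    f"than window_length={window_length}"
                )
    if dropped_seconds > 0:
        logger.warning(
            "Skipped %.3f s of data in intervals shorter than %s s",
            dropped_seconds,
            window_length,
        )
    return kept
```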

Example:

>>> from temporaldata import Interval
>>> from torch_brain.samplers import RandomFixedWindowSampler

>>> sampling_intervals = {
...     "session_1": Interval(0.0, 100.0),
...     "session_2": Interval(0.0, 200.0),
... }
>>> sampler = RandomFixedWindowSampler(
...     sampling_intervals=sampling_intervals,
...     window_length=1.0,
... )
>>> len(sampler)  # floor(100 / 1.0) + floor(200 / 1.0)
300