RandomFixedWindowSampler
- class torch_brain.samplers.RandomFixedWindowSampler(*, sampling_intervals, window_length, generator=None, drop_short=True)
Bases: torch.utils.data.sampler.Sampler
Samples fixed-length windows randomly from a collection of time intervals.
Given the sampling_intervals dictionary mapping session IDs to temporaldata.Interval objects, this sampler produces DatasetIndex objects for indexing a Dataset. Each call to __iter__() applies a fresh random temporal jitter and re-shuffles the windows, so every epoch explores slightly different positions within each interval.
In one epoch, the number of samples generated from a single contiguous interval of length \(L\) is:
\[N = \left\lfloor\frac{L}{\text{window\_length}}\right\rfloor\]
- Parameters:
  - sampling_intervals (Dict[str, Interval]) – Sampling intervals for each session. Typically obtained from get_sampling_intervals().
  - window_length (float) – Duration of each sampled window, in seconds.
  - generator (Optional[Generator]) – Optional RNG used for jitter and shuffling. If None (default), the default global PyTorch generator is used.
  - drop_short (bool) – If True (default), intervals shorter than window_length are skipped and a warning is logged. If False, a ValueError is raised for any short interval.
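This page does not spell out exactly how the per-epoch jitter is applied. The sketch below is one plausible reading, not the library's code. It assumes three steps for each interval: draw a uniform random offset in [0, window_length), tile windows from start + offset, then add one extra window when the combined slack at both edges adds up to at least window_length. Under that assumption every epoch yields exactly \(N = \lfloor L / \text{window\_length} \rfloor\) windows per interval even though their positions change. The sketch makes several substitutions. Python's random.Random stands in for a torch.Generator. Plain (start, end) tuples stand in for Interval and DatasetIndex. The names sample_epoch and Window are hypothetical.

```python
import random
from typing import Dict, List, Tuple

# Hypothetical stand-in for DatasetIndex: (session_id, start, end).
Window = Tuple[str, float, float]


def sample_epoch(
    intervals: Dict[str, List[Tuple[float, float]]],
    window_length: float,
    rng: random.Random,
    drop_short: bool = True,
) -> List[Window]:
    """Hypothetical sketch of one epoch of jittered fixed-window sampling."""
    windows: List[Window] = []
    for session_id, spans in intervals.items():
        for start, end in spans:
            length = end - start
            if length < window_length:
                if drop_short:
                    continue  # per the docs, the real sampler also logs a warning here
                raise ValueError(
                    f"{session_id}: interval ({start}, {end}) is shorter than window_length"
                )
            # Assumed jitter: shift the tiling grid by a random offset in [0, window_length).
            offset = rng.random() * window_length
            k = int((length - offset) // window_length)
            for i in range(k):
                t = start + offset + i * window_length
                windows.append((session_id, t, t + window_length))
            # Time not covered by the k tiled windows: left offset plus right remainder.
            slack = length - k * window_length
            if slack >= window_length:
                # The slack adds up to one window: place it flush with the edge that has
                # more leftover room (it may overlap the nearest tiled window).
                right_slack = slack - offset
                if right_slack > offset:
                    windows.append((session_id, end - window_length, end))
                else:
                    windows.append((session_id, start, start + window_length))
    rng.shuffle(windows)  # fresh order every epoch
    return windows
```

Calling sample_epoch repeatedly with the same rng gives different window positions each time, while the count per interval stays fixed. That fixed count is consistent with the constant len(sampler) in the example.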
Example:
>>> import numpy as np
>>> from temporaldata import Interval
>>> from torch_brain.samplers import RandomFixedWindowSampler
>>> sampling_intervals = {
...     "session_1": Interval(0.0, 100.0),
...     "session_2": Interval(0.0, 200.0),
... }
>>> sampler = RandomFixedWindowSampler(
...     sampling_intervals=sampling_intervals,
...     window_length=1.0,
... )
>>> len(sampler)
300
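The 300 in the example follows from the formula for \(N\). Session 1 contributes \(\lfloor 100 / 1.0 \rfloor = 100\) windows and session 2 contributes \(\lfloor 200 / 1.0 \rfloor = 200\). The helper below applies the same arithmetic when a session has several intervals. It assumes that len() sums \(N\) over every contiguous interval, which matches the example. The name expected_len is hypothetical, and intervals are given as plain (start, end) tuples instead of temporaldata.Interval objects.

```python
import math
from typing import Dict, List, Tuple


def expected_len(
    intervals: Dict[str, List[Tuple[float, float]]], window_length: float
) -> int:
    """Hypothetical helper: sum N = floor(L / window_length) over all intervals."""
    return sum(
        math.floor((end - start) / window_length)
        for spans in intervals.values()
        for start, end in spans
    )


# The two sessions from the example: 100 + 200 windows.
print(expected_len({"session_1": [(0.0, 100.0)], "session_2": [(0.0, 200.0)]}, 1.0))  # 300

# Several intervals per session: 10 + 2 + 0; the 0.5 s interval is too short to contribute.
print(expected_len({"session_1": [(0.0, 10.0), (20.0, 22.5)], "session_2": [(0.0, 0.5)]}, 1.0))  # 12
```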