"""Sequential fixed-window sampler (``torch_brain.samplers.sequential_fixed_window``)."""

import logging
from typing import List, Dict
from functools import cached_property

import torch
from temporaldata import Interval

from torch_brain.dataset import DatasetIndex


class SequentialFixedWindowSampler(torch.utils.data.Sampler[DatasetIndex]):
    r"""Samples fixed-length windows sequentially in a deterministic, reproducible order.

    Given the :obj:`sampling_intervals` dictionary mapping session IDs to
    :class:`temporaldata.Interval` objects, this sampler produces
    :class:`~torch_brain.dataset.DatasetIndex` objects in a fixed order: sessions
    in dictionary iteration order, intervals in the order they are stored, and
    windows in increasing start time within each interval. Windows are stepped
    through each interval using a configurable :obj:`step` size, making this
    sampler well-suited for evaluation where full coverage and reproducibility
    are required.

    If the stepped windows of an interval do not end exactly at the interval's
    end (i.e. ``interval_length - window_length`` is not an exact multiple of
    :obj:`step`), a final window aligned to the end of the interval is appended
    so that the entire interval is covered. This final window may overlap the
    previous one.

    Args:
        sampling_intervals: Sampling intervals for each session. Typically obtained
            from :meth:`~torch_brain.dataset.Dataset.get_sampling_intervals`.
        window_length: Duration of each sampled window in seconds. Must be
            greater than 0.
        step: Step size between the start of consecutive windows in seconds.
            If ``None`` (default), sets :obj:`step` to :obj:`window_length`
            (non-overlapping windows). Must be greater than 0.
        drop_short: If ``False`` (default), a :exc:`ValueError` is raised for any
            interval shorter than :obj:`window_length`. If ``True``, such
            intervals are skipped and a single warning summarizing the dropped
            duration is logged.

    Raises:
        ValueError: If :obj:`window_length` or :obj:`step` is not greater than 0.
            Errors about short intervals are raised lazily, the first time the
            sampler is iterated or its length is queried.

    Example::

        >>> import numpy as np
        >>> from temporaldata import Interval
        >>> from torch_brain.samplers import SequentialFixedWindowSampler
        >>> sampling_intervals = {
        ...     "session_1": Interval(0.0, 100.0),
        ...     "session_2": Interval(0.0, 100.0),
        ... }
        >>> sampler = SequentialFixedWindowSampler(
        ...     sampling_intervals=sampling_intervals,
        ...     window_length=10.0,
        ...     step=5.0,
        ... )
        >>> len(sampler)
        38
    """

    def __init__(
        self,
        *,
        sampling_intervals: Dict[str, Interval],
        window_length: float,
        step: float | None = None,
        drop_short: bool = False,
    ):
        # Validate with explicit exceptions rather than ``assert`` so the checks
        # still run under ``python -O``. ``not x > 0`` also rejects NaN.
        if not window_length > 0:
            raise ValueError(
                f"window_length must be greater than 0, got {window_length}."
            )
        # Only ``None`` means "default to window_length"; ``step or window_length``
        # would also silently replace an explicit (invalid) ``step=0``.
        if step is None:
            step = window_length
        if not step > 0:
            raise ValueError("Step must be greater than 0.")

        self.sampling_intervals = sampling_intervals
        self.window_length = window_length
        self.step = step
        self.drop_short = drop_short

    # The indices are deterministic, so they are computed once and cached.
    @cached_property
    def _indices(self) -> List[DatasetIndex]:
        """Build the ordered list of windows for every session and interval.

        Raises:
            ValueError: If an interval is shorter than ``window_length`` and
                ``drop_short`` is ``False``, or if ``drop_short`` is ``True`` and
                every interval was dropped for being too short.
        """
        indices: List[DatasetIndex] = []
        num_short_dropped = 0
        total_short_dropped = 0.0
        total_kept = 0.0

        for session_name, intervals in self.sampling_intervals.items():
            for start, end in zip(intervals.start, intervals.end):
                # Use Python floats so every emitted index has the same field types.
                start, end = float(start), float(end)
                interval_length = end - start

                if interval_length < self.window_length:
                    if not self.drop_short:
                        raise ValueError(
                            f"Interval {(start, end)} is too short to sample from. "
                            f"Minimum length is {self.window_length}."
                        )
                    # Count intervals (not just seconds) so that zero-length
                    # intervals are also reported as dropped.
                    num_short_dropped += 1
                    total_short_dropped += interval_length
                    continue

                total_kept += interval_length
                indices.extend(self._interval_windows(session_name, start, end))

        # Only reachable with drop_short=True; otherwise short intervals raise above.
        if num_short_dropped > 0:
            # Report the kept interval duration rather than
            # ``len(indices) * window_length``, which over-counts overlapping windows.
            logging.warning(
                f"Skipping {total_short_dropped} seconds of data due to short "
                f"intervals. Remaining: {total_kept} seconds."
            )
            if not indices:
                raise ValueError("All intervals are too short to sample from.")

        return indices

    def _interval_windows(
        self, session_name: str, start: float, end: float
    ) -> List[DatasetIndex]:
        """Return the ordered windows covering ``[start, end]`` for one interval.

        Expects ``end - start >= window_length``; the caller checks this.
        """
        window_starts = torch.arange(start, end, self.step, dtype=torch.float64)
        window_ends = window_starts + self.window_length
        # Keep only windows that lie entirely inside the interval.
        keep = window_ends <= end
        windows = [
            DatasetIndex(session_name, t_start, t_end)
            for t_start, t_end in zip(
                window_starts[keep].tolist(), window_ends[keep].tolist()
            )
        ]
        # Append an end-aligned window if the stepped windows stop short of the
        # interval end. The emptiness check guards against float rounding leaving
        # no stepped window even though the interval is long enough.
        if not windows or windows[-1].end < end:
            windows.append(
                DatasetIndex(session_name, end - self.window_length, end)
            )
        return windows

    def __len__(self):
        r"""Returns the total number of windows across all sessions."""
        return len(self._indices)

    def __iter__(self):
        r"""Yields :class:`~torch_brain.dataset.DatasetIndex` objects in sequential order."""
        yield from self._indices