# Source code for torch_brain.samplers.random_fixed_window
import math
import logging
from typing import Dict
from functools import cached_property
import torch
from temporaldata import Interval
from torch_brain.dataset import DatasetIndex
# [docs]
class RandomFixedWindowSampler(torch.utils.data.Sampler[DatasetIndex]):
    r"""Samples fixed-length windows randomly from a collection of time intervals.

    Given the :obj:`sampling_intervals` dictionary mapping session IDs to
    :class:`temporaldata.Interval` objects, this sampler produces
    :class:`~torch_brain.dataset.DatasetIndex` objects for indexing a
    :class:`~torch_brain.dataset.Dataset`. Each call to :meth:`__iter__` applies a
    fresh random temporal jitter and re-shuffles the windows, so every epoch explores
    slightly different positions within each interval.

    In one epoch, the number of samples generated from a single contiguous interval
    of length :math:`L` is:

    .. math::

        N = \left\lfloor\frac{L}{\text{window_length}}\right\rfloor

    The same formula is used by both :meth:`__len__` and :meth:`__iter__`, so the
    number of yielded items matches ``len(sampler)``.

    Args:
        sampling_intervals: Sampling intervals for each session.
            Typically obtained from
            :meth:`~torch_brain.dataset.Dataset.get_sampling_intervals`.
        window_length: Duration of each sampled window in seconds.
        generator: Optional RNG used for jitter and
            shuffling. If ``None`` (default), uses the default global PyTorch
            generator.
        drop_short: If ``True`` (default), intervals shorter than
            :obj:`window_length` are skipped and a warning is logged. If
            ``False``, a :exc:`ValueError` is raised for any short interval.

    Raises:
        ValueError: If :obj:`window_length` is not strictly positive.

    Example::

        >>> import numpy as np
        >>> from temporaldata import Interval
        >>> from torch_brain.samplers import RandomFixedWindowSampler
        >>> sampling_intervals = {
        ...     "session_1": Interval(0.0, 100.0),
        ...     "session_2": Interval(0.0, 200.0),
        ... }
        >>> sampler = RandomFixedWindowSampler(
        ...     sampling_intervals=sampling_intervals,
        ...     window_length=1.0,
        ... )
        >>> len(sampler)
        300
    """

    def __init__(
        self,
        *,
        sampling_intervals: Dict[str, Interval],
        window_length: float,
        generator: torch.Generator | None = None,
        drop_short: bool = True,
    ):
        if window_length <= 0:
            raise ValueError("window_length must be greater than 0.")
        self.sampling_intervals = sampling_intervals
        self.window_length = window_length
        self.generator = generator
        self.drop_short = drop_short

    def _iter_intervals(self):
        r"""Yields ``(session_id, start, end)`` for every contiguous interval.

        Bounds are converted to plain Python floats so that arithmetic, error
        messages and the produced indices do not depend on the array scalar type
        stored in the interval object.
        """
        for session_id, intervals in self.sampling_intervals.items():
            for start, end in zip(intervals.start, intervals.end):
                yield session_id, float(start), float(end)

    def _num_windows(self, start: float, end: float) -> int:
        r"""Returns how many windows one interval contributes per epoch.

        Returns ``0`` for an interval that is shorter than the window and is being
        dropped.

        Raises:
            ValueError: If the interval is shorter than the window and
                :obj:`drop_short` is ``False``.
        """
        interval_length = end - start
        if interval_length < self.window_length:
            if self.drop_short:
                return 0
            raise ValueError(
                f"Interval {(start, end)} is too short to sample from. "
                f"Minimum length is {self.window_length}."
            )
        return math.floor(interval_length / self.window_length)

    @cached_property
    def _estimated_len(self) -> int:
        r"""Total number of windows per epoch, computed once and cached.

        Raises:
            ValueError: If an interval is too short and :obj:`drop_short` is
                ``False``, or if no interval can hold a single window.
        """
        num_samples = 0
        total_short_dropped = 0.0
        for _, start, end in self._iter_intervals():
            count = self._num_windows(start, end)
            if count == 0:
                # Only reachable for short intervals that are being dropped.
                total_short_dropped += end - start
            num_samples += count

        if self.drop_short and total_short_dropped > 0:
            logging.warning(
                f"Skipping {total_short_dropped} seconds of data due to short "
                f"intervals. Remaining: {num_samples * self.window_length} seconds."
            )
        if num_samples == 0:
            raise ValueError("All intervals are too short to sample from.")
        return num_samples

    def __len__(self):
        r"""Returns the number of samples per epoch across all sessions."""
        return self._estimated_len

    def __iter__(self):
        r"""Yields shuffled :class:`~torch_brain.dataset.DatasetIndex` objects with random temporal jitter.

        Raises:
            ValueError: If an interval is too short and :obj:`drop_short` is
                ``False``, or if no interval can hold a single window.
        """
        # Evaluating the length first runs the validation and emits the
        # short-interval warning before any random numbers are drawn.
        if len(self) == 0:
            raise ValueError("All intervals are too short to sample from.")

        window_length = self.window_length
        indices = []
        for session_id, start, end in self._iter_intervals():
            expected = self._num_windows(start, end)
            if expected == 0:
                continue

            # sample a random offset in [0, window_length)
            left_offset = (
                torch.rand(1, generator=self.generator).item() * window_length
            )

            window_starts = torch.arange(
                start + left_offset,
                end,
                window_length,
                dtype=torch.float64,
            )
            window_ends = window_starts + window_length
            windows = [
                DatasetIndex(session_id, w_start, w_end)
                for w_start, w_end in zip(
                    window_starts.tolist(), window_ends.tolist()
                )
                if w_end <= end
            ]
            # Guard against floating-point rounding fitting one window too many.
            del windows[expected:]

            # The jitter usually leaves room for one window fewer than an
            # un-jittered tiling. Deciding this from the same count that __len__
            # uses (instead of re-deriving it from the offsets) keeps the number
            # of yielded samples equal to len(self) even under rounding.
            if len(windows) < expected:
                if windows:
                    right_offset = end - windows[-1].end
                else:
                    right_offset = end - start - left_offset
                # Place the extra window against whichever edge has more
                # uncovered data.
                if right_offset > left_offset:
                    windows.append(
                        DatasetIndex(session_id, end - window_length, end)
                    )
                else:
                    windows.append(
                        DatasetIndex(session_id, start, start + window_length)
                    )

            indices.extend(windows)

        # shuffle
        for position in torch.randperm(
            len(indices), generator=self.generator
        ).tolist():
            yield indices[position]