Source code for torch_brain.samplers.trial_sampler
from typing import Dict
import torch
from torch_brain.dataset import DatasetIndex
from temporaldata import Interval
[docs]
class TrialSampler(torch.utils.data.Sampler[DatasetIndex]):
    r"""Samples complete trial intervals without windowing.

    Unlike :class:`RandomFixedWindowSampler` and :class:`SequentialFixedWindowSampler`,
    which slice continuous recordings into fixed-length windows, :class:`TrialSampler`
    treats each individual interval in :obj:`sampling_intervals` as a complete trial and
    yields one :class:`~torch_brain.dataset.DatasetIndex` per trial. This is suited for
    trial-based experimental paradigms where each trial has a well-defined start and end
    time that should be preserved.

    Args:
        sampling_intervals: Sampling intervals for each session.
            Each individual interval within the session's :class:`temporaldata.Interval`
            object is treated as one trial. Typically obtained from
            :meth:`~torch_brain.dataset.Dataset.get_sampling_intervals`.
        shuffle: If ``False`` (default), trials are yielded in the order they appear in
            :obj:`sampling_intervals` (sessions in mapping order, trials in the order of
            each session's ``start``/``end`` arrays).
            If ``True``, trials are yielded in a randomly shuffled order; a new
            permutation is drawn every time the sampler is iterated.
        generator: Optional RNG used when
            :obj:`shuffle=True`. If ``None`` (default), uses the default global PyTorch
            generator.

    Example::

        >>> from temporaldata import Interval
        >>> import numpy as np
        >>> from torch_brain.samplers import TrialSampler
        >>> sampling_intervals = {
        ...     "session_1": Interval(
        ...         start=np.array([0.0, 5.0, 10.0]),
        ...         end=np.array([2.0, 8.0, 15.0]),
        ...     ),
        ... }
        >>> sampler = TrialSampler(
        ...     sampling_intervals=sampling_intervals,
        ...     shuffle=True,
        ... )
        >>> len(sampler)
        3
    """

    def __init__(
        self,
        *,
        sampling_intervals: Dict[str, Interval],
        shuffle: bool = False,
        generator: torch.Generator | None = None,
    ):
        self.sampling_intervals = sampling_intervals
        self.shuffle = shuffle
        self.generator = generator

    def __len__(self) -> int:
        r"""Returns the total number of trials across all sessions."""
        return sum(len(intervals) for intervals in self.sampling_intervals.values())

    def __iter__(self):
        r"""Yields one :class:`~torch_brain.dataset.DatasetIndex` per trial, optionally shuffled."""
        # Flatten every (session, trial) pair into a single list so that a
        # permutation can be drawn over all trials of all sessions at once.
        indices = [
            DatasetIndex(session_id, start, end)
            for session_id, intervals in self.sampling_intervals.items()
            for start, end in zip(intervals.start, intervals.end)
        ]

        if not self.shuffle:
            yield from indices
            return

        # Convert the permutation to plain Python ints in one step instead of
        # iterating the tensor, which would create a 0-dim tensor per trial and
        # rely on its ``__index__`` for every list lookup.
        order = torch.randperm(len(indices), generator=self.generator).tolist()
        for position in order:
            yield indices[position]