TrialSampler

class torch_brain.samplers.TrialSampler(*, sampling_intervals, shuffle=False, generator=None)

Bases: torch.utils.data.sampler.Sampler

Samples complete trial intervals without windowing.

Unlike RandomFixedWindowSampler and SequentialFixedWindowSampler, which slice continuous recordings into fixed-length windows, TrialSampler treats each individual interval in sampling_intervals as one complete trial and yields exactly one DatasetIndex per trial, spanning that trial's start and end times. Because no windowing is applied, samples differ in duration whenever trials do. This makes the sampler well suited to trial-based experimental paradigms, where each trial has a well-defined start and end time that should be preserved.
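The difference from windowed sampling can be sketched in plain Python. This is a simplified illustration, not the torch_brain implementation: TrialIndex is a hypothetical stand-in for DatasetIndex, each session maps to a list of (start, end) pairs instead of a temporaldata.Interval object, and the hypothetical fixed_window_indices uses simple back-to-back windows for contrast rather than the actual logic of RandomFixedWindowSampler or SequentialFixedWindowSampler.

```python
from typing import Dict, List, NamedTuple, Tuple


class TrialIndex(NamedTuple):
    """Hypothetical stand-in for torch_brain's DatasetIndex."""
    recording_id: str
    start: float
    end: float


# Each session maps to a list of (start, end) pairs, standing in for a
# temporaldata.Interval object.
Intervals = Dict[str, List[Tuple[float, float]]]


def trial_indices(intervals: Intervals) -> List[TrialIndex]:
    """One index per interval; trial boundaries are kept unchanged."""
    return [
        TrialIndex(session, start, end)
        for session, pairs in intervals.items()  # dicts keep insertion order
        for start, end in pairs
    ]


def fixed_window_indices(intervals: Intervals, window_length: float) -> List[TrialIndex]:
    """For contrast: cut each interval into back-to-back fixed-length windows.

    Any remainder shorter than window_length at the end of an interval is dropped.
    """
    out = []
    for session, pairs in intervals.items():
        for start, end in pairs:
            t = start
            while t + window_length <= end:
                out.append(TrialIndex(session, t, t + window_length))
                t += window_length
    return out


intervals = {"session_1": [(0.0, 2.0), (5.0, 8.0), (10.0, 15.0)]}
print(len(trial_indices(intervals)))              # 3 trials, bounds preserved
print(len(fixed_window_indices(intervals, 1.0)))  # 2 + 3 + 5 = 10 windows
```

With the same three intervals used in the example below, the trial view produces three indices whose durations (2.0, 3.0 and 5.0) match the trials, while 1.0-long windowing fragments the same data into ten indices.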

Parameters:
  • sampling_intervals (Dict[str, Interval]) – Mapping from each session’s ID to that session’s sampling intervals. Each individual interval within the session’s temporaldata.Interval object is treated as one trial. Typically obtained from get_sampling_intervals().

  • shuffle (bool) – If False (default), trials are yielded in the order they appear in sampling_intervals. If True, trials are yielded in a randomly shuffled order.

  • generator (Optional[Generator]) – Optional RNG used when shuffle=True. If None (default), uses the default global PyTorch generator.
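The shuffle and generator semantics described above can be sketched as follows. This is a hypothetical illustration that assumes the sampler draws a permutation of trial positions from the supplied generator; it uses Python's random.Random as a stand-in for torch.Generator so that it runs without PyTorch, and trial_order is not part of torch_brain.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def trial_order(trials: Sequence[T], shuffle: bool = False,
                rng: Optional[random.Random] = None) -> List[T]:
    """Hypothetical sketch of TrialSampler's ordering rules.

    random.Random stands in for torch.Generator in this sketch.
    """
    if not shuffle:
        return list(trials)  # keep the order in which trials were given
    # No rng supplied: fall back to the module-level RNG, playing the role
    # of PyTorch's global default generator.
    shuffler = rng.shuffle if rng is not None else random.shuffle
    perm = list(range(len(trials)))
    shuffler(perm)
    return [trials[i] for i in perm]


trials = ["t0", "t1", "t2"]
print(trial_order(trials))  # ['t0', 't1', 't2']

a = trial_order(trials, shuffle=True, rng=random.Random(0))
b = trial_order(trials, shuffle=True, rng=random.Random(0))
print(a == b)  # True: same seed, same order
```

With the real class, passing a seeded generator such as torch.Generator().manual_seed(0) should likewise make the shuffled trial order reproducible across runs.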

Example:

>>> from temporaldata import Interval
>>> import numpy as np
>>> from torch_brain.samplers import TrialSampler

>>> sampling_intervals = {
...     "session_1": Interval(
...         start=np.array([0.0, 5.0, 10.0]),
...         end=np.array([2.0, 8.0, 15.0]),
...     ),
... }
>>> sampler = TrialSampler(
...     sampling_intervals=sampling_intervals,
...     shuffle=True,
... )
>>> len(sampler)
3