OpenNeuroDataset

class torch_brain.datasets.OpenNeuroDataset(root, dataset_dir, split_type, recording_ids=None, transform=None, uniquify_channel_ids_with_subject=False, uniquify_channel_ids_with_session=True, split_ratios=(0.8, 0.1, 0.1), seed=42, **kwargs)

Bases: torch_brain.datasets.mixins.MultiChannelDatasetMixin, torch_brain.datasets.dataset.Dataset

Base class for OpenNeuro datasets.

This class provides an interface for loading, representing, and manipulating OpenNeuro datasets using the MultiChannelDatasetMixin and the Dataset interface. It supports various splitting strategies for machine learning workflows, notably ‘intrasession’, ‘intersubject’, and ‘intersession’ splits.

Parameters:
  • root (str) – Root directory containing processed OpenNeuro dataset artifacts.

  • dataset_dir (str) – Relative dataset directory within the root path.

  • split_type (Literal['intrasession', 'intersubject', 'intersession']) – The split strategy to use. Must be one of ‘intrasession’, ‘intersubject’, or ‘intersession’.

  • recording_ids (Optional[list[str]]) – List of recording IDs to include, or None to use all available recordings.

  • transform (Optional[Callable]) – Optional sample transform.

  • uniquify_channel_ids_with_subject (bool) – Whether to prefix channel IDs with subject.id via MultiChannelDatasetMixin. Defaults to False.

  • uniquify_channel_ids_with_session (bool) – Whether to prefix channel IDs with session.id via MultiChannelDatasetMixin. Defaults to True.

  • task_paradigm – The task paradigm of the dataset; valid values depend on the dataset. Defaults to None.

  • split_ratios (tuple[float, float, float]) – Tuple of three floats (train, val, test) giving the proportion of the dataset to use for the train, validation, and test splits, respectively. Each ratio must be in [0, 1] and the three must sum to 1.0; otherwise a ValueError is raised.

  • seed (int) – The seed for the random number generator. Used for computing splits in intersubject and intersession mode. Defaults to 42.
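The constraints on split_ratios and the role of seed can be illustrated with a small standalone sketch. It does not import torch_brain: validate_split_ratios and assign_subjects are hypothetical helpers, and both the tolerance-based sum check and the shuffle-then-slice assignment are assumptions made for illustration, not the library's implementation.

```python
import math
import random


def validate_split_ratios(split_ratios):
    """Hypothetical check mirroring the documented split_ratios constraints."""
    if len(split_ratios) != 3:
        raise ValueError("split_ratios must contain exactly three values")
    if any(r < 0.0 or r > 1.0 for r in split_ratios):
        raise ValueError("each ratio must lie in [0, 1]")
    # Assumption: compare with a tolerance, since a sum of floats such as
    # 0.8 + 0.1 + 0.1 is not guaranteed to equal 1.0 exactly.
    if not math.isclose(sum(split_ratios), 1.0, abs_tol=1e-9):
        raise ValueError("split_ratios must sum to 1.0")
    return tuple(split_ratios)


def assign_subjects(subject_ids, split_ratios=(0.8, 0.1, 0.1), seed=42):
    """Hypothetical seeded assignment of subjects to train/val/test."""
    train_ratio, val_ratio, _ = validate_split_ratios(split_ratios)
    shuffled = sorted(subject_ids)         # sort so the input order does not matter
    random.Random(seed).shuffle(shuffled)  # local RNG; global random state is untouched
    n_train = round(train_ratio * len(shuffled))
    n_val = round(val_ratio * len(shuffled))
    return {
        "train": shuffled[:n_train],
        "val": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],  # remainder goes to test
    }


subjects = [f"sub-{i:02d}" for i in range(1, 11)]
first = assign_subjects(subjects, seed=42)
second = assign_subjects(subjects, seed=42)
```

With a fixed seed, repeated calls produce the same assignment, which is the property a seed parameter is meant to provide for the intersubject and intersession modes.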

get_sampling_intervals(split=None)

Retrieve the sampling intervals for each recording according to the specified split.

If split is None, returns the full interval domain for every recording for unrestricted sampling. If a split (“train”, “val”, or “test”) is provided, returns only the intervals (within each recording) eligible for sampling under the current split type and task paradigm.

The selection of intervals is determined according to:
  • The current self.split_type (intrasession, intersubject, or intersession).

  • Whether a self.task_paradigm is specified, which influences the interval extraction.

Parameters:

split (Optional[Literal['train', 'val', 'test']]) – One of “train”, “val”, or “test” to select intervals corresponding to that split, or None to retrieve the entire domain for all recordings.

Return type:

dict[str, Interval]

Returns:

Dictionary mapping recording IDs to their valid Interval objects for sampling in the given split (or full Interval domain if split is None).

Raises:
  • ValueError – If the requested split or the dataset’s split_type is not recognized/supported.

  • KeyError – If a required split or assignment attribute is missing in a recording.

Notes

  • Intervals are defined based on recording domains and split logic.
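The contract described above (full domain when split is None, per-split intervals otherwise, ValueError for an unknown split, KeyError for a missing assignment) can be sketched without the library. Everything here is a stand-in: RECORDINGS is made-up data, intervals are plain (start, end) tuples in seconds rather than Interval objects, and sketch_get_sampling_intervals is a hypothetical function, not the method's actual implementation.

```python
from typing import Optional

VALID_SPLITS = ("train", "val", "test")

# Made-up recordings: a full domain plus one interval per split.
RECORDINGS = {
    "rec_a": {"domain": (0.0, 100.0),
              "train": (0.0, 80.0), "val": (80.0, 90.0), "test": (90.0, 100.0)},
    "rec_b": {"domain": (0.0, 50.0),
              "train": (0.0, 40.0), "val": (40.0, 45.0), "test": (45.0, 50.0)},
}


def sketch_get_sampling_intervals(recordings, split: Optional[str] = None):
    """Return a dict mapping recording ID to its interval for the given split."""
    if split is None:
        # No split requested: expose the whole domain of every recording.
        return {rid: rec["domain"] for rid, rec in recordings.items()}
    if split not in VALID_SPLITS:
        raise ValueError(f"unknown split: {split!r}")
    # A recording without an entry for this split raises KeyError here.
    return {rid: rec[split] for rid, rec in recordings.items()}


full = sketch_get_sampling_intervals(RECORDINGS)
val_only = sketch_get_sampling_intervals(RECORDINGS, "val")
```

The returned dictionary has one entry per recording ID in both cases, so downstream code can iterate over it the same way regardless of whether a split was requested.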

get_default_sampling_intervals(recording, split)

Get the default sampling intervals for a given split. These intervals are behavior-agnostic: they do not take into account any task or behavioral (event/label) annotations when creating the train, val, and test splits. Interval assignment is performed solely based on session or subject, not on in-task structure.

Notes

  • For split_type == “intrasession”, intervals are split causally into train, val, and test based on split_ratios.

  • For split_type == “intersubject” or “intersession”, only the assigned recordings are included for each split (using k-fold assignment); all others return an empty interval.

Return type:

Interval
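The two cases in the notes can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: intervals are lists of (start, end) tuples instead of Interval objects, an empty list stands in for an empty interval, "causally" is read as consecutive in time (train, then val, then test), and causal_split, sketch_default_intervals, and the assigned_split argument are hypothetical names. The k-fold assignment mentioned above is not reproduced; the sketch simply takes a recording's assigned split as given.

```python
def causal_split(start, end, split_ratios=(0.8, 0.1, 0.1)):
    """Cut [start, end] into consecutive train, val, and test segments."""
    train_ratio, val_ratio, _ = split_ratios
    duration = end - start
    train_end = start + train_ratio * duration
    val_end = train_end + val_ratio * duration
    # The test segment runs to the end of the recording, so no time is dropped.
    return {
        "train": [(start, train_end)],
        "val": [(train_end, val_end)],
        "test": [(val_end, end)],
    }


def sketch_default_intervals(domain, split, split_type, assigned_split=None,
                             split_ratios=(0.8, 0.1, 0.1)):
    """Hypothetical stand-in for the behavior-agnostic default intervals."""
    start, end = domain
    if split_type == "intrasession":
        # Every recording contributes a time slice to each split.
        return causal_split(start, end, split_ratios)[split]
    if split_type in ("intersubject", "intersession"):
        # A recording belongs to exactly one split; other splits get nothing.
        return [(start, end)] if assigned_split == split else []
    raise ValueError(f"unsupported split_type: {split_type!r}")


intra_val = sketch_default_intervals((0.0, 100.0), "val", "intrasession")
inter_hit = sketch_default_intervals((0.0, 100.0), "train", "intersubject",
                                     assigned_split="train")
inter_miss = sketch_default_intervals((0.0, 100.0), "test", "intersubject",
                                      assigned_split="train")
```

In the intrasession case each split sees a different time slice of the same recording, while in the other two modes a recording is either included whole or excluded entirely.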