Source code for torch_brain.dataset.mixins

import numpy as np
import pandas as pd
from temporaldata import Data

from torch_brain.utils import np_string_prefix


class SpikingDatasetMixin:
    """
    Mixin class for :class:`torch_brain.dataset.Dataset` subclasses containing
    spiking data.

    Provides:
        - ``get_unit_ids()`` for retrieving IDs of all included units.
        - ``compute_average_firing_rates()`` for a per-unit average firing
          rate table.
        - If the class attribute ``spiking_dataset_mixin_uniquify_unit_ids``
          is set to ``True``, unit IDs will be made unique across recordings
          by prefixing each unit ID with the corresponding ``session.id``.
          This helps avoid collisions when combining data from multiple
          sessions. (default: ``False``)

    The mixin relies on the host class for ``recording_ids``,
    ``get_recording(...)`` and a ``get_recording_hook(...)`` further along
    the MRO.
    """

    spiking_dataset_mixin_uniquify_unit_ids: bool = False

    def get_recording_hook(self, data: Data):
        """Optionally prefix ``data.units.id`` with the session ID, then chain.

        ``data`` is modified in place: when
        ``spiking_dataset_mixin_uniquify_unit_ids`` is truthy,
        ``data.units.id`` is replaced by its string form prefixed with
        ``"<session.id>/"``. The next ``get_recording_hook`` in the MRO is
        always invoked afterwards so that cooperating mixins can stack.
        """
        if self.spiking_dataset_mixin_uniquify_unit_ids:
            data.units.id = np_string_prefix(
                f"{data.session.id}/",
                data.units.id.astype(str),
            )
        super().get_recording_hook(data)

    def get_unit_ids(self) -> list[str]:
        """Return a sorted list of all unit IDs across all recordings in the dataset.

        IDs are concatenated as-is (no de-duplication). An empty list is
        returned when the dataset has no recordings.
        """
        ans = [self.get_recording(rid).units.id for rid in self.recording_ids]
        if not ans:
            # np.concatenate raises ValueError on an empty sequence.
            return []
        return np.sort(np.concatenate(ans)).tolist()

    def compute_average_firing_rates(self) -> pd.DataFrame:
        """
        Compute and return the average firing rates for all units in the dataset.

        For every recording, the spike count of each unit is divided by the
        summed duration of ``data.spikes.domain``. Units that never appear in
        ``data.spikes.unit_index`` get a rate of ``0``.

        Returns:
            pd.DataFrame: DataFrame indexed by unit ID (index name
                ``'unit_id'``), containing a column 'firing_rate' (Hz) with
                the average firing rate for each unit in the dataset. The
                frame is empty when the dataset has no recordings.
        """
        unit_ids = []
        firing_rates = []
        for rid in self.recording_ids:
            data = self.get_recording(rid)
            # Total observed time is the sum over all domain intervals.
            total_time = (data.spikes.domain.end - data.spikes.domain.start).sum()
            idx, counts = np.unique(data.spikes.unit_index, return_counts=True)
            # Start from zeros so that silent units are still reported.
            fr = np.zeros(len(data.units))
            fr[idx] = counts / total_time
            unit_ids.append(data.units.id)
            firing_rates.append(fr)

        if unit_ids:
            unit_ids = np.concatenate(unit_ids)
            firing_rates = np.concatenate(firing_rates)
        else:
            # np.concatenate raises ValueError on an empty sequence; fall back
            # to empty arrays so callers still get a well-formed frame.
            unit_ids = np.array([], dtype=object)
            firing_rates = np.array([], dtype=float)

        df = pd.DataFrame({"firing_rate": firing_rates}, index=unit_ids)
        df.index.name = "unit_id"
        return df
class CalciumImagingDatasetMixin:
    """
    Mixin class for :class:`torch_brain.dataset.Dataset` subclasses containing
    calcium imaging data.

    Provides:
        - ``get_roi_ids()`` for retrieving IDs of all included ROIs.
        - If the class attribute
          ``calcium_imaging_dataset_mixin_uniquify_roi_ids`` is set to
          ``True``, ROI IDs will be made unique across recordings by
          prefixing each ROI ID with the corresponding ``session.id``. This
          helps avoid collisions when combining data from multiple sessions.
          (default: ``False``)

    The mixin relies on the host class for ``recording_ids``,
    ``get_recording(...)`` and a ``get_recording_hook(...)`` further along
    the MRO.
    """

    calcium_imaging_dataset_mixin_uniquify_roi_ids: bool = False

    def get_recording_hook(self, data: Data):
        """Optionally prefix ``data.rois.id`` with the session ID, then chain.

        ``data`` is modified in place: when
        ``calcium_imaging_dataset_mixin_uniquify_roi_ids`` is truthy,
        ``data.rois.id`` is replaced by its string form prefixed with
        ``"<session.id>/"``. The next ``get_recording_hook`` in the MRO is
        always invoked afterwards so that cooperating mixins can stack.
        """
        if self.calcium_imaging_dataset_mixin_uniquify_roi_ids:
            data.rois.id = np_string_prefix(
                f"{data.session.id}/",
                data.rois.id.astype(str),
            )
        super().get_recording_hook(data)

    def get_roi_ids(self) -> list[str]:
        """Return a sorted list of all ROI IDs across all recordings in the dataset.

        IDs are concatenated as-is (no de-duplication). An empty list is
        returned when the dataset has no recordings.
        """
        ans = [self.get_recording(rid).rois.id for rid in self.recording_ids]
        if not ans:
            # np.concatenate raises ValueError on an empty sequence.
            return []
        return np.sort(np.concatenate(ans)).tolist()
class MultiChannelDatasetMixin:
    """
    Mixin class for :class:`torch_brain.dataset.Dataset` subclasses containing
    multi-channel recordings (e.g., EEG, ECoG, EMG, sEEG, etc).

    Provides:
        - ``get_channel_ids()`` for retrieving sorted channel IDs from
          recording views returned by ``get_recording(...)``.
        - Configurable channel-ID uniquification by prepending metadata
          components before each channel id:
          ``multichannel_dataset_mixin_uniquify_channel_ids_with_session``
          prepends ``session.id`` (default ``False``) and
          ``multichannel_dataset_mixin_uniquify_channel_ids_with_subject``
          prepends ``subject.id`` (default ``True``). This ``subject.id``
          uniquification allows channels with the same name in the same
          subject to be treated as the same channel across sessions. If both
          are enabled, the prefix order is ``subject.id/session.id``.

    The mixin relies on the host class for ``recording_ids``,
    ``get_recording(...)`` and a ``get_recording_hook(...)`` further along
    the MRO.
    """

    # Channel-ID uniquification toggles used by get_recording_hook.
    # Prefix order is always subject/session when enabled.
    multichannel_dataset_mixin_uniquify_channel_ids_with_subject: bool = True
    multichannel_dataset_mixin_uniquify_channel_ids_with_session: bool = False

    def get_recording_hook(self, data: Data):
        """Optionally prefix ``data.channels.id``, then chain to the next hook.

        ``data`` is modified in place: when the configured prefix is
        non-empty, ``data.channels.id`` is replaced by its string form with
        that prefix prepended. The next ``get_recording_hook`` in the MRO is
        always invoked afterwards so that cooperating mixins can stack.
        """
        prefix = self._build_multichannel_channel_id_prefix(data)
        if prefix:
            data.channels.id = np_string_prefix(
                prefix,
                data.channels.id.astype(str),
            )
        super().get_recording_hook(data)

    def _build_multichannel_channel_id_prefix(self, data: Data) -> str:
        """Build the channel-ID prefix for ``data`` from the enabled toggles.

        Returns ``""`` when both toggles are off; otherwise each enabled
        component is followed by ``"/"``, subject first and session second.
        """
        prefix = ""
        if self.multichannel_dataset_mixin_uniquify_channel_ids_with_subject:
            prefix += f"{data.subject.id}/"
        if self.multichannel_dataset_mixin_uniquify_channel_ids_with_session:
            prefix += f"{data.session.id}/"
        return prefix

    def get_channel_ids(self) -> list[str]:
        """Return sorted channel IDs across recordings.

        ``get_channel_ids`` aggregates ``rec.channels.id`` from
        ``get_recording(...)``. Any subject/session uniquification is applied
        there according to
        ``multichannel_dataset_mixin_uniquify_channel_ids_with_subject`` and
        ``multichannel_dataset_mixin_uniquify_channel_ids_with_session``.

        IDs are concatenated as-is, so an ID shared by several recordings
        appears once per recording in the result. An empty list is returned
        when the dataset has no recordings.
        """
        ans = [self.get_recording(rid).channels.id for rid in self.recording_ids]
        if not ans:
            # np.concatenate raises ValueError on an empty sequence.
            return []
        return np.sort(np.concatenate(ans)).tolist()