Source code for torch_brain.dataset.mixins
import numpy as np
import pandas as pd
from temporaldata import Data
from torch_brain.utils import np_string_prefix
[docs]
class SpikingDatasetMixin:
    """
    Mixin class for :class:`torch_brain.dataset.Dataset` subclasses containing spiking data.

    Provides:

    - ``get_unit_ids()`` for retrieving IDs of all included units.
    - ``compute_average_firing_rates()`` for computing the average firing rate (Hz)
      of every included unit.
    - If the class attribute ``spiking_dataset_mixin_uniquify_unit_ids`` is set to ``True``,
      unit IDs will be made unique across recordings by prefixing each unit ID with the
      corresponding ``session.id``. This helps avoid collisions when combining data from
      multiple sessions. (default: ``False``)
    """

    spiking_dataset_mixin_uniquify_unit_ids: bool = False

    def get_recording_hook(self, data: Data):
        """Optionally prefix unit IDs with ``"<session.id>/"``, then chain to the next hook.

        ``data.units.id`` is replaced in place when
        ``spiking_dataset_mixin_uniquify_unit_ids`` is ``True``.
        """
        if self.spiking_dataset_mixin_uniquify_unit_ids:
            data.units.id = np_string_prefix(
                f"{data.session.id}/",
                data.units.id.astype(str),
            )
        # Cooperative call so other mixins / the base class can apply their hooks too.
        super().get_recording_hook(data)

    def get_unit_ids(self) -> list[str]:
        """Return a sorted list of all unit IDs across all recordings in the dataset.

        IDs are read from ``get_recording(rid).units.id`` for every recording ID. An ID
        present in several recordings is listed once per recording. Returns an empty
        list when the dataset has no recordings.
        """
        per_recording = [self.get_recording(rid).units.id for rid in self.recording_ids]
        if not per_recording:
            # np.concatenate raises ValueError on an empty sequence.
            return []
        return np.sort(np.concatenate(per_recording)).tolist()

    def compute_average_firing_rates(self) -> pd.DataFrame:
        """
        Compute and return the average firing rates for all units in the dataset.

        For each recording, a unit's rate is its spike count divided by the summed
        length of the intervals in ``data.spikes.domain``. Units without any spikes
        get a rate of 0. Rows follow recording order, then unit order within each
        recording. An ID shared by several recordings yields repeated index entries.

        Returns:
            pd.DataFrame: DataFrame indexed by unit ID (index name ``"unit_id"``),
            containing a column 'firing_rate' (Hz) with the average firing rate for
            each unit in the dataset. Empty when the dataset has no recordings.
        """
        unit_ids = []
        firing_rates = []
        for rid in self.recording_ids:
            data = self.get_recording(rid)
            # Total observed time, summed over all domain intervals.
            total_time = (data.spikes.domain.end - data.spikes.domain.start).sum()
            # unit_index points each spike at a row of data.units.
            idx, counts = np.unique(data.spikes.unit_index, return_counts=True)
            fr = np.zeros(len(data.units))
            fr[idx] = counts / total_time
            unit_ids.append(data.units.id)
            firing_rates.append(fr)

        if not unit_ids:
            # np.concatenate raises ValueError on an empty sequence.
            return pd.DataFrame(
                {"firing_rate": np.zeros(0)},
                index=pd.Index([], name="unit_id"),
            )

        df = pd.DataFrame(
            {"firing_rate": np.concatenate(firing_rates)},
            index=np.concatenate(unit_ids),
        )
        df.index.name = "unit_id"
        return df
[docs]
class CalciumImagingDatasetMixin:
    """
    Mixin class for :class:`torch_brain.dataset.Dataset` subclasses containing calcium imaging data.

    Provides:

    - ``get_roi_ids()`` for retrieving IDs of all included ROIs.
    - If the class attribute ``calcium_imaging_dataset_mixin_uniquify_roi_ids`` is set to ``True``,
      ROI IDs will be made unique across recordings by prefixing each ROI ID with the
      corresponding ``session.id``. This helps avoid collisions when combining data from
      multiple sessions. (default: ``False``)
    """

    calcium_imaging_dataset_mixin_uniquify_roi_ids: bool = False

    def get_recording_hook(self, data: Data):
        """Optionally prefix ROI IDs with ``"<session.id>/"``, then chain to the next hook.

        ``data.rois.id`` is replaced in place when
        ``calcium_imaging_dataset_mixin_uniquify_roi_ids`` is ``True``.
        """
        if self.calcium_imaging_dataset_mixin_uniquify_roi_ids:
            data.rois.id = np_string_prefix(
                f"{data.session.id}/",
                data.rois.id.astype(str),
            )
        # Cooperative call so other mixins / the base class can apply their hooks too.
        super().get_recording_hook(data)

    def get_roi_ids(self) -> list[str]:
        """Return a sorted list of all ROI IDs across all recordings in the dataset.

        IDs are read from ``get_recording(rid).rois.id`` for every recording ID. An ID
        present in several recordings is listed once per recording. Returns an empty
        list when the dataset has no recordings.
        """
        per_recording = [self.get_recording(rid).rois.id for rid in self.recording_ids]
        if not per_recording:
            # np.concatenate raises ValueError on an empty sequence.
            return []
        return np.sort(np.concatenate(per_recording)).tolist()
[docs]
class MultiChannelDatasetMixin:
    """
    Mixin class for :class:`torch_brain.dataset.Dataset` subclasses containing
    multi-channel recordings (e.g., EEG, ECoG, EMG, sEEG, etc).

    Provides:

    - ``get_channel_ids()`` for retrieving the sorted, de-duplicated channel IDs
      from recording views returned by ``get_recording(...)``.
    - Configurable channel-ID uniquification by prepending metadata components
      before each channel ID:
      ``multichannel_dataset_mixin_uniquify_channel_ids_with_subject`` prepends
      ``subject.id`` (default ``True``), and
      ``multichannel_dataset_mixin_uniquify_channel_ids_with_session`` prepends
      ``session.id`` (default ``False``). Prefixing with ``subject.id`` alone lets
      channels with the same name in the same subject be treated as the same
      channel across sessions. If both are enabled, the prefix order is
      ``subject.id/session.id/``.
    """

    # Channel-ID uniquification toggles used by get_recording_hook.
    # Prefix order is always subject/session when enabled.
    multichannel_dataset_mixin_uniquify_channel_ids_with_subject: bool = True
    multichannel_dataset_mixin_uniquify_channel_ids_with_session: bool = False

    def get_recording_hook(self, data: Data):
        """Prefix ``data.channels.id`` per the toggles above, then chain to the next hook.

        ``data.channels.id`` is replaced in place only when the prefix is non-empty.
        """
        prefix = self._build_multichannel_channel_id_prefix(data)
        if prefix:
            data.channels.id = np_string_prefix(
                prefix,
                data.channels.id.astype(str),
            )
        # Cooperative call so other mixins / the base class can apply their hooks too.
        super().get_recording_hook(data)

    def _build_multichannel_channel_id_prefix(self, data: Data) -> str:
        """Return ``"<subject.id>/"``, ``"<session.id>/"``, both (in that order), or ``""``."""
        prefix = ""
        if self.multichannel_dataset_mixin_uniquify_channel_ids_with_subject:
            prefix += f"{data.subject.id}/"
        if self.multichannel_dataset_mixin_uniquify_channel_ids_with_session:
            prefix += f"{data.session.id}/"
        return prefix

    def get_channel_ids(self) -> list[str]:
        """Return sorted, unique channel IDs across recordings.

        ``get_channel_ids`` aggregates ``rec.channels.id`` from ``get_recording(...)``.
        Any subject/session uniquification is applied there according to
        ``multichannel_dataset_mixin_uniquify_channel_ids_with_subject`` and
        ``multichannel_dataset_mixin_uniquify_channel_ids_with_session``.
        Returns an empty list when the dataset has no recordings.
        """
        per_recording = [
            self.get_recording(rid).channels.id for rid in self.recording_ids
        ]
        if not per_recording:
            # np.concatenate raises ValueError on an empty sequence.
            return []
        # np.unique sorts and de-duplicates. With subject-only prefixing, the same
        # channel ID appears in every session of that subject and must be listed once.
        return np.unique(np.concatenate(per_recording)).tolist()