Mixins

class SpikingDatasetMixin

Bases: object

Mixin class for torch_brain.dataset.Dataset subclasses containing spiking data.

Provides:
  • get_unit_ids() for retrieving IDs of all included units.

  • Optional unit ID uniquification: if the class attribute spiking_dataset_mixin_uniquify_unit_ids is set to True, each unit ID is prefixed with the corresponding session.id, making unit IDs unique across recordings. This avoids collisions when data from multiple sessions is combined. Default: False.
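To illustrate why prefixing helps, here is a minimal standalone sketch of the idea. The helper uniquify_ids and the "/" separator are hypothetical; this page does not state the exact format the mixin produces.

```python
from typing import Sequence


def uniquify_ids(session_id: str, ids: Sequence[str], sep: str = "/") -> list[str]:
    """Prefix every ID with its session ID (hypothetical helper).

    The separator is an assumption; the mixin's actual format may differ.
    """
    return [f"{session_id}{sep}{unit_id}" for unit_id in ids]


# Two sessions that both contain a unit called "unit_0".
a = uniquify_ids("session_a", ["unit_0", "unit_1"])
b = uniquify_ids("session_b", ["unit_0"])
print(a + b)  # ['session_a/unit_0', 'session_a/unit_1', 'session_b/unit_0']
```

Without the prefix, both sessions would report "unit_0" and the two units could not be told apart after merging.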

spiking_dataset_mixin_uniquify_unit_ids: bool = False
get_recording_hook(data)
get_unit_ids()

Return a sorted list of all unit IDs across all recordings in the dataset.

Return type:

list[str]
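The behavior described for get_unit_ids() can be sketched as a sorted union over recordings. This is an illustration under assumptions, not the library's code: recordings are modeled as a plain dict mapping a session ID to its unit IDs, and the sketch deduplicates IDs, which this page does not confirm.

```python
def get_unit_ids(recordings: dict[str, list[str]]) -> list[str]:
    """Return a sorted list of all unit IDs across recordings.

    `recordings` is a hypothetical stand-in for the dataset's recording
    objects: it maps a session ID to that recording's unit IDs.
    """
    unit_ids: set[str] = set()
    for ids in recordings.values():
        unit_ids.update(ids)  # deduplication is an assumption of this sketch
    return sorted(unit_ids)


recordings = {"session_a": ["unit_1", "unit_0"], "session_b": ["unit_0"]}
print(get_unit_ids(recordings))  # ['unit_0', 'unit_1']
```

Note that in this sketch the two distinct "unit_0" units collapse into a single entry, which is the kind of collision the uniquify flag is meant to prevent.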

compute_average_firing_rates()

Compute and return the average firing rates for all units in the dataset.

Returns:

DataFrame indexed by unit ID, with a column ‘firing_rate’ holding the average firing rate (Hz) of each unit in the dataset.

Return type:

pd.DataFrame
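One plausible way to compute such a table is sketched below. The input layout (a list of (unit_ids, spike_unit_ids, duration_s) tuples) is hypothetical, and so is the averaging rule: the sketch pools total spikes over total recorded time per unit, whereas the documented method might instead average per-recording rates.

```python
from collections import Counter

import pandas as pd


def compute_average_firing_rates(recordings) -> pd.DataFrame:
    """Average firing rate (Hz) per unit across recordings.

    `recordings` is a hypothetical list of
    (unit_ids, spike_unit_ids, duration_s) tuples: the units present in the
    recording, one unit ID per spike, and the recording length in seconds.
    """
    spike_counts: Counter = Counter()
    recorded_time_s: Counter = Counter()
    for unit_ids, spike_unit_ids, duration_s in recordings:
        # Every unit present accrues recording time, even if it never fires.
        for unit_id in unit_ids:
            recorded_time_s[unit_id] += duration_s
        spike_counts.update(spike_unit_ids)
    # Pooled rate: total spikes divided by total time the unit was recorded.
    rates = {u: spike_counts[u] / t for u, t in recorded_time_s.items()}
    df = pd.DataFrame({"firing_rate": pd.Series(rates, dtype=float)})
    df.index.name = "unit_id"
    return df.sort_index()


recordings = [
    (["u0", "u1"], ["u0", "u0", "u1"], 2.0),  # 2 s recording
    (["u0"], ["u0", "u0", "u0", "u0"], 2.0),  # another 2 s recording
]
print(compute_average_firing_rates(recordings))  # u0 -> 1.5 Hz, u1 -> 0.5 Hz
```

Pooling weights longer recordings more heavily; averaging per-recording rates would weight each recording equally, and the two give different results when durations differ.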

class CalciumImagingDatasetMixin

Bases: object

Mixin class for torch_brain.dataset.Dataset subclasses containing calcium imaging data.

Provides:
  • get_roi_ids() for retrieving IDs of all included ROIs.

  • Optional ROI ID uniquification: if the class attribute calcium_imaging_dataset_mixin_uniquify_roi_ids is set to True, each ROI ID is prefixed with the corresponding session.id, making ROI IDs unique across recordings. This avoids collisions when data from multiple sessions is combined. Default: False.

calcium_imaging_dataset_mixin_uniquify_roi_ids: bool = False
get_recording_hook(data)
get_roi_ids()

Return a sorted list of all ROI IDs across all recordings in the dataset.

Return type:

list[str]
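Because the uniquify switch is a class attribute, a subclass enables it simply by overriding the attribute. The toy classes below illustrate that pattern; ToyRoiMixin, PlainDataset, UniqueDataset, the `recordings` dict, and the "/" separator are all hypothetical and are not torch_brain code. For brevity the toy applies the prefix inside get_roi_ids; this page does not say where the real mixin applies it.

```python
class ToyRoiMixin:
    """Hypothetical stand-in mirroring the documented flag; not torch_brain code."""

    calcium_imaging_dataset_mixin_uniquify_roi_ids: bool = False

    # Subclasses are assumed to set `self.recordings`: session ID -> ROI IDs.
    recordings: dict[str, list[str]]

    def get_roi_ids(self) -> list[str]:
        roi_ids: set[str] = set()
        for session_id, ids in self.recordings.items():
            if self.calcium_imaging_dataset_mixin_uniquify_roi_ids:
                # The "/" separator is an assumption.
                ids = [f"{session_id}/{roi_id}" for roi_id in ids]
            roi_ids.update(ids)
        return sorted(roi_ids)


class PlainDataset(ToyRoiMixin):
    def __init__(self, recordings: dict[str, list[str]]):
        self.recordings = recordings


class UniqueDataset(PlainDataset):
    # Overriding the class attribute is all a subclass needs to do.
    calcium_imaging_dataset_mixin_uniquify_roi_ids = True


recs = {"s1": ["roi_0", "roi_1"], "s2": ["roi_0"]}
print(PlainDataset(recs).get_roi_ids())   # ['roi_0', 'roi_1']
print(UniqueDataset(recs).get_roi_ids())  # ['s1/roi_0', 's1/roi_1', 's2/roi_0']
```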