Source code for torch_brain.datasets.PeiPandarinathNLB2021
from collections.abc import Callable
from pathlib import Path
from typing import Literal
from ._utils import get_processed_dir
from .dataset import Dataset
from .mixins import SpikingDatasetMixin
[docs]
class PeiPandarinathNLB2021(SpikingDatasetMixin, Dataset):
    """
    Curated spiking neural activity datasets from the Neural Latents Benchmark
    2021 (NLB'21).

    .. admonition:: Preprocessing

        To download and prepare this dataset, run

        .. code:: shell

            brainsets prepare pei_pandarinath_nlb_2021

    Args:
        root: Root directory of processed data. When ``None``, the value
            returned by ``get_processed_dir()`` is used.
        recording_ids: Optional list of recording IDs, forwarded to the base
            ``Dataset``.
        transform: Optional callable, forwarded to the base ``Dataset``.
        dirname: Sub-directory of ``root`` that holds this dataset.
        **kwargs: Extra keyword arguments forwarded to the base ``Dataset``.
    """

    def __init__(
        self,
        root: str | Path | None = None,
        recording_ids: list[str] | None = None,
        transform: Callable | None = None,
        dirname: str = "pei_pandarinath_nlb_2021",
        **kwargs,
    ):
        if root is None:
            root = get_processed_dir()
        super().__init__(
            dataset_dir=Path(root) / dirname,
            recording_ids=recording_ids,
            transform=transform,
            # NOTE(review): presumably namespaces these ID attributes so they
            # stay distinct across datasets; the behavior lives in ``Dataset``
            # — confirm there.
            namespace_attributes=["session.id", "units.id"],
            **kwargs,
        )
        # Flag consumed by SpikingDatasetMixin; presumably makes unit IDs
        # unique across recordings — confirm against the mixin.
        self.spiking_dataset_mixin_uniquify_unit_ids = True

    def get_sampling_intervals(
        self,
        split: Literal["train", "valid", "test"] | None = None,
    ):
        """Return the sampling domain of every recording in this dataset.

        Args:
            split: ``"train"``, ``"valid"`` or ``"test"`` selects each
                recording's ``<split>_domain`` attribute; ``None`` selects its
                ``domain`` attribute.

        Returns:
            dict mapping each ID in ``self.recording_ids`` to the selected
            domain attribute of that recording (presumably an interval
            object — confirm against the recording data type).

        Raises:
            ValueError: If ``split`` is neither ``None`` nor one of the
                supported split names.
        """
        valid_splits = ("train", "valid", "test")
        if split is not None and split not in valid_splits:
            # Reject unknown splits up front rather than failing inside
            # getattr on the first recording (or silently returning {} when
            # there are no recordings).
            raise ValueError(
                f"Invalid split {split!r}; expected one of {valid_splits} or None."
            )
        domain_key = "domain" if split is None else f"{split}_domain"
        return {
            rid: getattr(self.get_recording(rid), domain_key)
            for rid in self.recording_ids
        }