
A Minimal Training Setup on NLB Maze with torch_brain#

This example walks through a minimal training pipeline for decoding 2D hand velocity from motor cortex spiking activity, using the “jenkins_maze_train” recording from the Neural Latents Benchmark (NLB) MC_Maze dataset.

It is intended as a starting point for new users of torch_brain and brainsets, and shows how to:

  1. Build a custom Dataset on top of a brainsets recording.

  2. Sample fixed-length trials around a behavioral event using TrialSampler.

  3. Train one of three small decoders (a linear readout, a bidirectional GRU, or a dilated TCN).

Note: Although this notebook will run on a CPU, a GPU runtime is recommended. On Google Colab, select Runtime > Change runtime type > T4 GPU.

Setup#

Install dependencies:

!pip install scikit-learn matplotlib
!pip install git+https://github.com/neuro-galaxy/brainsets git+https://github.com/neuro-galaxy/torch_brain


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
spopt 0.7.0 requires pulp>=2.8, but you have pulp 2.7.0 which is incompatible.
inflect 7.5.0 requires typeguard>=4.0.1, but you have typeguard 2.13.3 which is incompatible.
Successfully installed brainsets-0.2.1.dev33+g74d2d357d einops-0.6.1 hdmf-4.3.1 lightning-utilities-0.15.3 pulp-2.7.0 pynwb-3.1.3 ray-2.55.1 ruamel-yaml-0.19.1 temporaldata-0.1.4 torch_brain-0.1.1.dev54+g4c763ccc7 torchmetrics-1.9.0 torchtyping-0.1.5 typeguard-2.13.3 uv-0.11.15

Then preprocess the dataset using brainsets:

!brainsets prepare pei_pandarinath_nlb_2021 --raw-dir data/raw --processed-dir data/processed -s jenkins_maze_train


Preparing pei_pandarinath_nlb_2021...
Raw data directory: /content/data/raw
Processed data directory: /content/data/processed
Detected brainsets installation from git+https://github.com/neuro-galaxy/brainsets@74d2d357da15287b03f16a656f478f26242eb07a
Building temporary virtual environment for /usr/local/lib/python3.12/dist-packages/brainsets_pipelines/pei_pandarinath_nlb_2021/pipeline.py
Installed 116 packages in 242ms
Discovered 2 manifest items
[Status] DOWNLOADING
PATH                                                  SIZE    DONE    DONE% CHECKSUM STATUS          MESSAGE
sub-Jenkins_ses-small_desc-train_behavior+ecephys.nwb 29.2 MB 29.2 MB  100%    ok    done                   
Summary:                                              29.2 MB 29.2 MB                1 done                 
                                                              100.00%                                       
[Status] Loading NWB
[Status] Extracting Metadata
[Status] Extracting Spikes
[Status] Extracting Trials
WARNING:root:The ndarrays in column 'target_pos' do not all have the same shape.
WARNING:root:The ndarrays in column 'barrier_pos' do not all have the same shape.
[Status] Creating Splits
[Status] DONE
import numpy as np
import torch
from torch import nn, Tensor
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
# Hyperparameters (feel free to play with these)
BIN_SIZE = 0.01 # seconds
BATCH_SIZE = 8
EPOCHS = 100
LR = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cuda

Defining a Simple & Custom Dataset#

brainsets provides a dataset class, PeiPandarinathNLB2021, which covers the NLB Maze dataset and handles the file I/O.

We subclass this dataset and (re-)define two things on top of its file I/O:

  • get_sampling_intervals: Decides which time windows in the recording count as samples. Here, each sample is a 700 ms window around movement onset, and we split into train/val using the NLB-provided split indicator.

  • __getitem__: Defines how a time-sliced sample is turned into model-compatible Tensors.
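As a rough illustration of the first bullet, the windowing and split logic can be sketched with plain NumPy. The onset times and split labels below are made up for the example; in the actual dataset they come from the recording's trials.

```python
import numpy as np

# Hypothetical movement-onset times (in seconds) and split labels for four trials.
move_onset_times = np.array([1.00, 3.50, 6.20, 9.00])
split_indicator = np.array(["train", "val", "train", "train"])

# Each sample spans 250 ms before onset to 450 ms after onset (700 ms in total).
starts = move_onset_times - 0.25
ends = move_onset_times + 0.45

# Boolean masks select the trials that belong to each split.
train_mask = split_indicator == "train"
val_mask = split_indicator == "val"
train_windows = np.stack([starts[train_mask], ends[train_mask]], axis=1)
val_windows = np.stack([starts[val_mask], ends[val_mask]], axis=1)

print(train_windows.shape, val_windows.shape)
```

Each row of `train_windows` and `val_windows` is one (start, end) pair, which plays the role that an Interval plays in the class below.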

from typing import Literal
from temporaldata import Interval
from torch_brain.utils import bin_spikes
from torch_brain.dataset import DatasetIndex
from brainsets.datasets import PeiPandarinathNLB2021

class SimpleNLBMazeDataset(PeiPandarinathNLB2021):
    sample_length = 0.7
    out_dim = 2
    out_sampling_rate = 1000.0

    def __init__(self, root: str, split: Literal["train", "val"], bin_size: float):
        # recording_ids picks which session(s) inside the dataset to load.
        # We just want to load the maze_train session
        super().__init__(root=root, recording_ids=["jenkins_maze_train"])

        # This recording only specifies train and validation sets;
        # the test set is kept hidden for online evaluation.
        assert split in ("train", "val")

        # store some attributes that are useful later
        self.split = split
        self.bin_size = bin_size
        self.out_samples = round(self.sample_length * self.out_sampling_rate)
        self.num_bins = round(self.sample_length / self.bin_size)
        # get_unit_ids returns the list of neurons recorded in this session.
        self.num_units = len(self.get_unit_ids())

    # Contract between Datasets and Samplers:
    # get_sampling_intervals() returns {recording_id: Interval} listing
    # the windows the sampler may draw from.
    # Sampler will emit one DatasetIndex per sample.
    def get_sampling_intervals(self, *_args, **_kwargs):
        rid = self.recording_ids[0]  # since we only have 1 recording
        recording = self.get_recording(rid)

        # Taking trials to be relative to the movement onset time
        # from 250ms before onset to 450ms after onset
        # (as stated in the NLB paper Appendix A.5.1).
        move_onset_times = recording.trials.move_onset_time
        trials = Interval(move_onset_times - 0.25, move_onset_times + 0.45)

        # The NLB dataset also provided us a default assignment of
        # training and validation trials.
        # `.select_by_mask()` is our standard way to filter an Interval
        # down to a subset based on a boolean mask.
        trial_split_indicator = recording.trials.split_indicator.astype(str)
        train_trials = trials.select_by_mask(trial_split_indicator == "train")
        val_trials = trials.select_by_mask(trial_split_indicator == "val")

        if self.split == "train":
            return {rid: train_trials}
        elif self.split == "val":
            return {rid: val_trials}

    # `index` is a DatasetIndex(recording_id, start, end)
    # produced by the sampler.
    def __getitem__(self, index: DatasetIndex):
        # super().__getitem__ returns a sliced view of the recording, with all
        # modalities (.spikes, .units, .hand.vel, ...) already cropped (lazily).
        data = super().__getitem__(index)

        # In this example, we have designed all models to:
        # - take in a Tensor of shape (Number of bins, Number of neurons), and
        # - return a Tensor of shape (Number of output timesteps, Output dimension).

        # Spikes are an irregular event stream — bin them into a regular grid.
        X = bin_spikes(data.spikes, num_units=len(data.units), bin_size=self.bin_size)
        X = torch.from_numpy(X).float()  # shape: (num_bins, num_units)

        # Hand velocity is already a regularly-sampled signal, so we just rescale.
        Y = data.hand.vel / 200.0  # approximate z-score normalization
        Y = torch.from_numpy(Y).float()  # shape: (out_samples, out_dim)
        return X, Y
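To make the binning step concrete, here is a minimal NumPy sketch of turning an irregular stream of spike events into a (num_bins, num_units) count grid. It is not the implementation of `bin_spikes`; the function name `bin_spike_events` and the toy spike times are assumptions made for illustration.

```python
import numpy as np


def bin_spike_events(timestamps, unit_index, num_units, start, end, bin_size):
    """Count spikes per (time bin, unit) inside the window [start, end)."""
    num_bins = round((end - start) / bin_size)
    # Map every spike time to the index of the bin it falls into.
    bin_index = np.floor((timestamps - start) / bin_size).astype(int)
    # Drop spikes that land outside the window.
    keep = (bin_index >= 0) & (bin_index < num_bins)
    counts = np.zeros((num_bins, num_units), dtype=np.float32)
    # np.add.at accumulates correctly when several spikes share a cell.
    np.add.at(counts, (bin_index[keep], unit_index[keep]), 1.0)
    return counts


# Toy data: four spikes from three units inside a 700 ms window.
timestamps = np.array([0.001, 0.004, 0.012, 0.695])
unit_index = np.array([0, 0, 1, 2])
counts = bin_spike_events(
    timestamps, unit_index, num_units=3, start=0.0, end=0.7, bin_size=0.01
)
print(counts.shape)
```

With a 10 ms bin size the 700 ms window yields 70 bins, which is the `num_bins` value stored by the dataset.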

Creating the Datasets, Samplers, and DataLoaders#

💡 This is where we come across the main pattern for creating data pipelines with torch_brain:

  • Dataset tells the sampler where sampling is allowed,

  • Sampler decides what samples to load (by emitting DatasetIndex objects), and

  • DataLoader batches the samples as usual.
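The division of labour in the list above can be sketched in pure Python. Everything here (`ToyIndex`, `ToyDataset`, `toy_trial_sampler`, `toy_batches`) is a hypothetical stand-in written for this illustration; it mirrors the roles of the real classes, not their code.

```python
import random
from typing import NamedTuple


class ToyIndex(NamedTuple):
    """Stand-in for a (recording_id, start, end) sample index."""

    recording_id: str
    start: float
    end: float


class ToyDataset:
    """Knows where sampling is allowed and how to load one window."""

    def __init__(self, windows):
        self.windows = windows  # {recording_id: [(start, end), ...]}

    def get_sampling_intervals(self):
        return self.windows

    def __getitem__(self, index):
        # A real dataset would slice the recording here; we return the duration.
        return round(index.end - index.start, 3)


def toy_trial_sampler(sampling_intervals, shuffle=False, seed=0):
    """Emit one index per allowed window, optionally in shuffled order."""
    indices = [
        ToyIndex(rid, start, end)
        for rid, windows in sampling_intervals.items()
        for start, end in windows
    ]
    if shuffle:
        random.Random(seed).shuffle(indices)
    return indices


def toy_batches(dataset, sampler, batch_size):
    """Group the sampled items into lists of at most batch_size."""
    items = [dataset[index] for index in sampler]
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


dataset = ToyDataset({"session_a": [(0.75, 1.45), (3.25, 3.95), (5.95, 6.65)]})
sampler = toy_trial_sampler(dataset.get_sampling_intervals(), shuffle=True)
batches = toy_batches(dataset, sampler, batch_size=2)
print(batches)
```

The key point is that the dataset is indexed by time windows rather than by integers, which is why the sampler has to be passed to the loader explicitly.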

from torch_brain.samplers import TrialSampler
from torch.utils.data import DataLoader  # standard PyTorch loader

DATA_ROOT = "data/processed"  # This is where we stored the processed dataset

train_ds = SimpleNLBMazeDataset(DATA_ROOT, split="train", bin_size=BIN_SIZE)
# We want to sample "one-trial-at-a-time", so we use the TrialSampler
train_sampler = TrialSampler(
    sampling_intervals=train_ds.get_sampling_intervals(),
    shuffle=True,
)
# Note the sampler is passed explicitly; it is not the default random/sequential
# sampler PyTorch picks for an indexable dataset.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=train_sampler)
print(f"Number of units: {train_ds.num_units}")
print(f"Number of training samples: {len(train_sampler)}")

# Validation Dataset, Sampler, and DataLoader
val_ds = SimpleNLBMazeDataset(DATA_ROOT, split="val", bin_size=BIN_SIZE)
val_sampler = TrialSampler(sampling_intervals=val_ds.get_sampling_intervals())
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, sampler=val_sampler)
print(f"Number of validation samples: {len(val_sampler)}")

print(f"Number of units:    {train_ds.num_units}")
print(f"Bins per sample:    {train_ds.num_bins}  (bin size = {BIN_SIZE}s)")
print(f"Target samples:     {train_ds.out_samples}  (at {train_ds.out_sampling_rate} Hz)")
print(f"Train trials:       {len(train_ds.get_sampling_intervals()[train_ds.recording_ids[0]])}")
print(f"Val trials:         {len(val_ds.get_sampling_intervals()[val_ds.recording_ids[0]])}")
Number of units: 142
Number of training samples: 75
Number of validation samples: 25
Number of units:    142
Bins per sample:    70  (bin size = 0.01s)
Target samples:     700  (at 1000.0 Hz)
Train trials:       75
Val trials:         25

Let’s first peek at a single sample to confirm the shapes match what we expect, and visualize the binned spikes (input) and hand velocity (target) for one trial.

first_sample_index = next(iter(train_sampler))
print(
    f"First sample:\n"
    f"    recording_id: {first_sample_index.recording_id},\n"
    f"    start time: {first_sample_index.start},\n"
    f"    end time: {first_sample_index.end}\n"
)

X, Y = train_ds[first_sample_index]
print(f"X shape: {tuple(X.shape)}  (num_bins, num_units)")
print(f"Y shape: {tuple(Y.shape)}  (out_samples, out_dim)")

fig, axes = plt.subplots(2, 1, figsize=(5, 5))

axes[0].imshow(X.T.numpy(), aspect="auto", cmap="Greys", origin="lower", interpolation="nearest")
axes[0].set_title("Binned spikes (input)")
axes[0].set_xlabel("Time bin")
axes[0].set_ylabel("Unit")

t = np.linspace(-0.25, 0.45, train_ds.out_samples)
axes[1].plot(t, Y[:, 0].numpy(), label="$v_x$")
axes[1].plot(t, Y[:, 1].numpy(), label="$v_y$")
axes[1].axvline(0, color="k", linestyle="--", alpha=0.3, label="movement onset")
axes[1].set_title("Hand velocity (target)")
axes[1].set_xlabel("Time relative to movement onset (s)")
axes[1].set_ylabel("Normalized velocity")
axes[1].legend()

plt.tight_layout()
plt.show()
First sample:
    recording_id: jenkins_maze_train,
    start time: 195.154,
    end time: 195.85399999999998

X shape: (70, 142)  (num_bins, num_units)
Y shape: (700, 2)  (out_samples, out_dim)
[Figure: binned spikes (input) and hand velocity (target) for one training trial]

The Model#

Three small decoders are defined in the hidden cells below: Linear, GRU, and TCN.

  • Linear: flatten + a single nn.Linear layer.

  • GRU: bidirectional GRU, then a per-timestep linear readout and an interpolation to upsample from num_bins to out_samples.

  • TCN: a stack of dilated 1D convolutions, followed by the same interpolation + readout.

All three follow the same interface: they take a Tensor of shape (batch, num_bins, num_units) and return a Tensor of shape (batch, out_samples, out_dim).
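The GRU and TCN produce one value per input bin, so their output has to be stretched from num_bins to out_samples time steps. The NumPy sketch below shows one way to do such a linear upsampling; it is an illustration of the idea only, the helper name `upsample_linear` is hypothetical, and placing values at bin centres is an assumption of this sketch rather than a statement about how `nn.functional.interpolate` is implemented.

```python
import numpy as np


def upsample_linear(y, out_samples):
    """Linearly resample y of shape (num_bins, out_dim) to (out_samples, out_dim)."""
    num_bins, out_dim = y.shape
    # Place each value at the centre of its bin / output sample on a shared
    # [0, 1] time axis (an assumption of this sketch).
    src_t = (np.arange(num_bins) + 0.5) / num_bins
    dst_t = (np.arange(out_samples) + 0.5) / out_samples
    # np.interp works on 1D signals, so interpolate each output dimension in turn.
    columns = [np.interp(dst_t, src_t, y[:, d]) for d in range(out_dim)]
    return np.stack(columns, axis=1)


# A ramp and a constant signal sampled on 70 bins, upsampled to 700 steps.
coarse = np.stack([np.linspace(0.0, 1.0, 70), np.ones(70)], axis=1)
fine = upsample_linear(coarse, 700)
print(coarse.shape, "->", fine.shape)
```

Linear interpolation never overshoots the input range, so the upsampled prediction stays as smooth as the per-bin prediction it came from.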

Model Definitions#

Feel free to look around here!

Linear#


class Linear(nn.Module):
    def __init__(self, in_units, in_bins, out_dim, out_samples, dropout=0.2):
        super().__init__()
        self.out_dim = out_dim
        self.out_samples = out_samples

        input_size = in_units * in_bins
        output_size = out_dim * out_samples
        self.net = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(input_size, output_size)
        )

    def forward(self, x: Tensor) -> Tensor:
        batch_size = x.size(0)
        y = self.net(x.flatten(start_dim=1))
        y = y.view(batch_size, self.out_samples, self.out_dim)
        return y

GRU#


class GRU(nn.Module):
    def __init__(
        self,
        in_units,
        in_bins,
        out_dim,
        out_samples,
        hidden_dim=64,
        num_layers=2,
        bidirectional=True,
        dropout=0.2,
    ):
        super().__init__()
        self.out_dim = out_dim
        self.out_samples = out_samples

        self.gru = nn.GRU(
            input_size=in_units,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        self.readout = nn.Linear(
            in_features=2 * hidden_dim if bidirectional else hidden_dim,
            out_features=out_dim,
        )

    def forward(self, x: Tensor) -> Tensor:
        z, _ = self.gru(x)
        y = self.readout(z)
        y = y.permute(0, 2, 1)  # (B, T, D) ->  (B, D, T)
        y = nn.functional.interpolate(y, self.out_samples, mode="linear")
        y = y.permute(0, 2, 1)  # (B, D, T) -> (B, T, D)
        return y

TCN#


class TCN(nn.Module):
    def __init__(
        self,
        in_units,
        in_bins,
        out_dim,
        out_samples,
        hidden_dim=64,
        num_layers=8,
        kernel_size=3,
        dropout=0.2,
    ):
        super().__init__()
        self.out_dim = out_dim
        self.out_samples = out_samples

        layers = []
        in_channels = in_units
        for i in range(num_layers):
            dilation = 2**i
            padding = (kernel_size - 1) * dilation // 2
            layers.append(nn.Dropout(dropout))
            layers.append(
                nn.Conv1d(
                    in_channels,
                    hidden_dim,
                    kernel_size,
                    padding=padding,
                    dilation=dilation,
                )
            )
            layers.append(nn.ReLU())
            in_channels = hidden_dim
        self.net = nn.Sequential(*layers)
        self.readout = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: Tensor) -> Tensor:
        z = x.permute(0, 2, 1)  # (B, T, C) -> (B, C, T)
        z = self.net(z)
        z = nn.functional.interpolate(z, self.out_samples, mode="linear")
        z = z.permute(0, 2, 1)  # (B, C, T) -> (B, T, C)
        y = self.readout(z)
        return y
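A quick arithmetic check helps explain the padding choice in the TCN. The sketch below applies the output-length formula documented for `nn.Conv1d` to each layer and accumulates the receptive field; the helper name `conv1d_output_length` is hypothetical, and the defaults (kernel size 3, 8 layers, 70 input bins) are taken from this notebook.

```python
def conv1d_output_length(length, kernel_size, dilation, padding, stride=1):
    """Output length of a 1D convolution for a given input length."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


num_bins, kernel_size, num_layers = 70, 3, 8

length = num_bins
receptive_field = 1
lengths = []
for i in range(num_layers):
    dilation = 2**i
    # Same padding rule as in the TCN class above.
    padding = (kernel_size - 1) * dilation // 2
    length = conv1d_output_length(length, kernel_size, dilation, padding)
    lengths.append(length)
    # Each layer widens the receptive field by (kernel_size - 1) * dilation bins.
    receptive_field += (kernel_size - 1) * dilation

print(lengths, receptive_field)
```

Every layer keeps the sequence at 70 bins, and the stacked receptive field of 511 bins is wider than the 70-bin window, so with these defaults the top layer can draw on the whole trial (plus zero padding at the edges).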

Instantiating the model#

model = GRU(  # try: Linear, GRU, TCN
    in_units=train_ds.num_units,
    in_bins=train_ds.num_bins,
    out_dim=train_ds.out_dim,
    out_samples=train_ds.out_samples,
).to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTrainable parameters: {num_params:,}")
print(model)
Trainable parameters: 154,626
GRU(
  (gru): GRU(142, 64, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
  (readout): Linear(in_features=128, out_features=2, bias=True)
)
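The parameter count printed above can be reproduced by hand. The sketch below assumes the standard GRU layout (three gates, each with input weights, hidden weights, and two bias vectors); the helper name `gru_parameter_count` is hypothetical.

```python
def gru_parameter_count(input_size, hidden_size, num_layers, bidirectional):
    """Count the weights and biases of a stacked GRU."""
    num_directions = 2 if bidirectional else 1
    total = 0
    layer_input = input_size
    for _ in range(num_layers):
        # 3 gates x (input weights + hidden weights) plus 2 bias vectors per gate.
        weights = 3 * hidden_size * (layer_input + hidden_size)
        biases = 2 * 3 * hidden_size
        total += num_directions * (weights + biases)
        # Deeper layers receive the concatenated outputs of both directions.
        layer_input = num_directions * hidden_size
    return total


gru_params = gru_parameter_count(
    input_size=142, hidden_size=64, num_layers=2, bidirectional=True
)
readout_params = 128 * 2 + 2  # Linear(128 -> 2): weights plus biases
print(gru_params + readout_params)  # 154626
```

The two GRU layers contribute 79,872 and 74,496 parameters, and the readout adds 258, which matches the 154,626 reported by the model summary.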

Training#

A standard PyTorch loop: MSE loss against the hand velocity, an AdamW optimizer, and an R² score computed on the validation set at the end of each epoch.

from sklearn.metrics import r2_score

optim = torch.optim.AdamW(model.parameters(), lr=LR)

val_r2_history = []

for epoch in (epoch_pbar := tqdm(range(EPOCHS))):
    model.train()
    for X, Y in train_loader:
        X, Y = X.to(device), Y.to(device)
        pred = model(X)
        loss = nn.functional.mse_loss(pred, Y)
        optim.zero_grad()
        loss.backward()
        optim.step()

    with torch.no_grad():
        model.eval()
        preds, targets = [], []
        for X, Y in val_loader:
            X, Y = X.to(device), Y.to(device)
            preds.append(model(X))
            targets.append(Y)
        pred = torch.cat(preds).flatten(0, 1).cpu()
        target = torch.cat(targets).flatten(0, 1).cpu()
        r2 = r2_score(target, pred)
        val_r2_history.append(r2)
        epoch_pbar.set_description(f"Val R²: {r2:.3f}")
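For reference, the validation metric in the loop above can be written out in a few lines of NumPy. With its default settings, `r2_score` computes one R² per output column and averages them uniformly; the sketch below does the same for the simple case where every target column has non-zero variance. The function name `r2_uniform_average` and the toy arrays are assumptions made for illustration.

```python
import numpy as np


def r2_uniform_average(target, pred):
    """Mean over output columns of 1 - SS_res / SS_tot."""
    ss_res = ((target - pred) ** 2).sum(axis=0)
    ss_tot = ((target - target.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))


# Toy targets with two output dimensions (think v_x and v_y).
target = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])
perfect = target.copy()
mean_only = np.tile(target.mean(axis=0), (3, 1))

print(r2_uniform_average(target, perfect))    # 1.0
print(r2_uniform_average(target, mean_only))  # 0.0
```

A score of 1 means perfect prediction, 0 means the model does no better than predicting the mean velocity, and negative values mean it does worse than that baseline.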

Evaluation#

Plot the R² curve over training and compare predicted vs. actual hand velocity on one validation trial.

fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(val_r2_history)
ax.set_xlabel("Epoch")
ax.set_ylabel("Validation R²")
ax.set_title("Validation R² over training")
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
[Figure: validation R² over training]

Let’s look at an example of how our model’s predictions compare with the ground truth!

model.eval()
with torch.no_grad():
    X, Y = val_ds[next(iter(val_sampler))]
    pred = model(X.unsqueeze(0).to(device)).squeeze(0).cpu()

t = np.linspace(-0.25, 0.45, val_ds.out_samples)
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=False)
names = ["$v_x$", "$v_y$"]

# Plot $v_x$ and $v_y$
for i, name in enumerate(names):
    axes[i].plot(t, Y[:, i].numpy(), label="actual", color="k")
    axes[i].plot(t, pred[:, i].numpy(), label="predicted", color="green")
    axes[i].axvline(0, color="k", linestyle="--", alpha=0.3)
    axes[i].set_xlabel("Time relative to movement onset (s)")
    axes[i].set_ylabel(name)
    axes[i].legend(loc="upper left")
axes[0].set_title("Predicted vs. actual $v_x$", usetex=False)
axes[1].set_title("Predicted vs. actual $v_y$", usetex=False)

plt.tight_layout()
plt.show()
[Figure: predicted vs. actual $v_x$ and $v_y$ on one validation trial]
