# Source code for torch_brain.models.poyo

import inspect
import logging
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from temporaldata import Data

from torch_brain.batching import pad8, track_mask8
from torch_brain.dataset import Dataset
from torch_brain.nn import (
    Embedding,
    InfiniteVocabEmbedding,
    RotaryCrossAttention,
    RotarySelfAttention,
    RotaryTimeEmbedding,
)
from torch_brain.utils import (
    create_linspace_latent_tokens,
    create_start_end_unit_tokens,
)


[docs] class POYO(nn.Module): """POYO model from `Azabou et al. 2023, A Unified, Scalable Framework for Neural Population Decoding <https://arxiv.org/abs/2310.16046>`_. POYO is a transformer-based model for neural decoding from electrophysiological recordings. 1. Input tokens are constructed by combining unit embeddings, token type embeddings, and time embeddings for each spike in the sequence. 2. The input sequence is compressed using cross-attention, where learnable latent tokens (each with an associated timestamp) attend to the input tokens. 3. The compressed latent token representations undergo further refinement through multiple self-attention processing layers. 4. Query tokens are constructed for the desired outputs by combining session embeddings, and output timestamps. 5. These query tokens attend to the processed latent representations through cross-attention, producing outputs in the model's dimensional space (dim). 6. Finally, a task-specific linear layer maps the outputs from the model dimension to the appropriate output dimension. 
Args: sequence_length: Maximum duration of the input spike sequence (in seconds) latent_step: Timestep of the latent grid (in seconds) num_latents_per_step: Number of unique latent tokens (repeated at every latent step) dim: Hidden dimension of the model dim_out: Output dimension of the model depth: Number of processing layers (self-attentions in the latent space) dim_head: Dimension of each attention head cross_heads: Number of attention heads used in a cross-attention layer self_heads: Number of attention heads used in a self-attention layer ffn_dropout: Dropout rate for feed-forward networks lin_dropout: Dropout rate for linear layers atn_dropout: Dropout rate for attention emb_init_scale: Scale for embedding initialization t_min: Minimum timestamp resolution for rotary embeddings t_max: Maximum timestamp resolution for rotary embeddings """ def __init__( self, *, sequence_length: float, latent_step: float, num_latents_per_step: int = 64, dim_out: int, dim: int = 512, depth: int = 2, dim_head: int = 64, cross_heads: int = 1, self_heads: int = 8, ffn_dropout: float = 0.2, lin_dropout: float = 0.4, atn_dropout: float = 0.0, emb_init_scale: float = 0.02, t_min: float = 1e-4, t_max: float = 2.0627, ): super().__init__() self._validate_params(sequence_length, latent_step) self.sequence_length = sequence_length self.latent_step = latent_step self.num_latents_per_step = num_latents_per_step # embeddings self.unit_emb = InfiniteVocabEmbedding(dim, init_scale=emb_init_scale) self.session_emb = InfiniteVocabEmbedding(dim, init_scale=emb_init_scale) self.token_type_emb = Embedding(4, dim, init_scale=emb_init_scale) self.latent_emb = Embedding( num_latents_per_step, dim, init_scale=emb_init_scale ) self.rotary_emb = RotaryTimeEmbedding( head_dim=dim_head, rotate_dim=dim_head // 2, t_min=t_min, t_max=t_max, ) self.dropout = nn.Dropout(p=lin_dropout) # encoder layer self.enc_atn = RotaryCrossAttention( dim=dim, heads=cross_heads, dropout=atn_dropout, dim_head=dim_head, 
rotate_value=True, ) self.enc_ffn = nn.Sequential( nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout) ) # process layers self.proc_layers = nn.ModuleList([]) for i in range(depth): self.proc_layers.append( nn.Sequential( RotarySelfAttention( dim=dim, heads=self_heads, dropout=atn_dropout, dim_head=dim_head, rotate_value=True, ), nn.Sequential( nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout), ), ) ) # decoder layer self.dec_atn = RotaryCrossAttention( dim=dim, heads=cross_heads, dropout=atn_dropout, dim_head=dim_head, rotate_value=False, ) self.dec_ffn = nn.Sequential( nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout) ) # Output projections + loss self.readout = nn.Linear(dim, dim_out) self.dim = dim
[docs] def forward( self, *, # input sequence input_unit_index: torch.Tensor, input_timestamps: torch.Tensor, input_token_type: torch.Tensor, input_mask: torch.Tensor | None = None, # latent sequence latent_index: torch.Tensor, latent_timestamps: torch.Tensor, # Metadata for queries session_index: torch.Tensor, # output sequence output_timestamps: torch.Tensor, output_mask: torch.Tensor | None = None, unpack_output: bool = False, ) -> torch.Tensor | List[torch.Tensor]: """Forward pass of the POYO model. The model processes input spike sequences through its encoder-processor-decoder architecture to generate task-specific predictions. Args: input_unit_index: Unit indices of shape :math:`(B, N_{in})`. input_timestamps: Spike timestamps of shape :math:`(B, N_{in})`. input_token_type: Token type indices of shape :math:`(B, N_{in})`. input_mask: Boolean mask of shape :math:`(B, N_{in})`. latent_index: Latent token indices of shape :math:`(B, N_{lat})`. latent_timestamps: Latent token timestamps of shape :math:`(B, N_{lat})`. session_index: Session indices of shape :math:`(B,)`. output_timestamps: Output query timestamps of shape :math:`(B, N_{out})`. output_mask: Boolean mask of shape :math:`(B, N_{out})`. Required when ``unpack_output=True``. unpack_output: If ``False``, returns a padded tensor of shape :math:`(B, N_{out}, D_{out})`; use ``output_mask`` to index valid entries. If ``True``, returns a list of :math:`B` tensors each of shape :math:`(N_{out,i}, D_{out})` containing only the valid outputs. Returns: Shape :math:`(B, N_{out}, D_{out})` when ``unpack_output=False``, or a list of :math:`B` tensors of shape :math:`(N_{out,i}, D_{out})` when ``unpack_output=True``. 
""" if self.unit_emb.is_lazy(): raise ValueError( "Unit vocabulary has not been initialized, please use " "`model.unit_emb.initialize_vocab(unit_ids)`" ) if self.session_emb.is_lazy(): raise ValueError( "Session vocabulary has not been initialized, please use " "`model.session_emb.initialize_vocab(session_ids)`" ) # input inputs = self.unit_emb(input_unit_index) + self.token_type_emb(input_token_type) input_timestamp_emb = self.rotary_emb(input_timestamps) # latents latents = self.latent_emb(latent_index) latent_timestamp_emb = self.rotary_emb(latent_timestamps) # outputs output_queries = self.session_emb(session_index).unsqueeze(1) output_queries = output_queries.expand(-1, output_timestamps.size(1), -1) output_timestamp_emb = self.rotary_emb(output_timestamps) # encode latents = latents + self.enc_atn( latents, inputs, latent_timestamp_emb, input_timestamp_emb, input_mask, ) latents = latents + self.enc_ffn(latents) # process for self_attn, self_ff in self.proc_layers: latents = latents + self.dropout(self_attn(latents, latent_timestamp_emb)) latents = latents + self.dropout(self_ff(latents)) # decode output_queries = output_queries + self.dec_atn( output_queries, latents, output_timestamp_emb, latent_timestamp_emb, ) output_latents = output_queries + self.dec_ffn(output_queries) output = self.readout(output_latents) if unpack_output and output_mask is None: raise ValueError("output_mask is required when unpack_output=True") if unpack_output: output = [output[b][output_mask[b]] for b in range(output.size(0))] return output
[docs] def tokenize(self, data: Data) -> Dict: r"""Tokenizer used to tokenize Data for the POYO model. This tokenizer can be called as a transform. If you are applying multiple transforms, make sure to apply this one last. This code runs on CPU. Do not access GPU tensors inside this function. """ # context window start, end = 0, self.sequence_length ### prepare input unit_ids = data.units.id spike_unit_index = data.spikes.unit_index spike_timestamps = data.spikes.timestamps # create start and end tokens for each unit ( se_token_type_index, se_unit_index, se_timestamps, ) = create_start_end_unit_tokens(unit_ids, start, end) # append start and end tokens to the spike sequence spike_token_type_index = np.concatenate( [se_token_type_index, np.zeros_like(spike_unit_index)] ) spike_unit_index = np.concatenate([se_unit_index, spike_unit_index]) spike_timestamps = np.concatenate([se_timestamps, spike_timestamps]) # unit_index is relative to the recording, so we want it to map it to # the global unit index local_to_global_map = np.array(self.unit_emb.tokenizer(unit_ids)) spike_unit_index = local_to_global_map[spike_unit_index] ### prepare latents latent_index, latent_timestamps = create_linspace_latent_tokens( start, end, step=self.latent_step, num_latents_per_step=self.num_latents_per_step, ) # create session index for output session_index = self.session_emb.tokenizer(data.session.id) data_dict = dict( # input sequence (keys/values for the encoder) input_unit_index=pad8(spike_unit_index), input_timestamps=pad8(spike_timestamps), input_token_type=pad8(spike_token_type_index), input_mask=track_mask8(spike_unit_index), # latent sequence latent_index=latent_index, latent_timestamps=latent_timestamps, # metadata needed for decoder queries session_index=session_index, ) return data_dict
def _validate_params(self, sequence_length, latent_step): r"""Ensure: sequence_length, and latent_step are floating point numbers greater than zero. And sequence_length is a multiple of latent_step. """ if not isinstance(sequence_length, float): raise ValueError("sequence_length must be a float") if not sequence_length > 0: raise ValueError("sequence_length must be greater than 0") self.sequence_length = sequence_length if not isinstance(latent_step, float): raise ValueError("latent_step must be a float") if not latent_step > 0: raise ValueError("latent_step must be greater than 0") self.latent_step = latent_step # check if sequence_length is a multiple of latent_step if abs(sequence_length % latent_step) > 1e-10: logging.warning( f"sequence_length ({sequence_length}) is not a multiple of latent_step " f"({latent_step}). This is a simple warning, and this behavior is allowed." )
[docs] def init_vocabs(self, dataset: Dataset): """Initializes model's unit_emb and session_emb vocabularies. Args: dataset: A :class:`Dataset` with :class:`SpikingDatasetMixin`. """ if hasattr(dataset, "get_unit_ids") and inspect.ismethod(dataset.get_unit_ids): unit_ids = dataset.get_unit_ids() else: raise ValueError( "Could not call method ``get_unit_ids()`` on input ``dataset``." " Perhaps this is not a spiking dataset?" " Consider adding ``SpikingDatasetMixin`` to your Dataset class, which would" " provide this method." ) self.unit_emb.initialize_vocab(unit_ids) self.session_emb.initialize_vocab(dataset.recording_ids)
[docs] @classmethod def load_pretrained(cls, checkpoint_path: str | Path) -> "POYO": """Load a pretrained POYO model from a checkpoint file. Args: checkpoint_path: Path to the checkpoint file containing model weights and hyperparameters. Returns: An instance of the POYO model with weights loaded from the checkpoint. """ # For now, we are loading from the checkpoint generated using the official # Lightning trainer ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) hparams = ckpt["hyper_parameters"]["model"] hparams.pop("_target_", None) state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()} # Infer `dim_out` from shape of readout weights dim_out = state_dict["readout.weight"].size(0) model = cls(**hparams, dim_out=dim_out) model.load_state_dict(state_dict) return model
class _GEGLU(nn.Module):
    """GEGLU activation (Gated Gaussian Error Linear Unit) from
    "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).

    The last axis of the input is split into two equal halves: the first half
    is the value, the second half is passed through GELU and used as the gate.
    """

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return value * F.gelu(gate)


class FeedForward(nn.Module):
    """Feed-forward block: expanding linear layer, GEGLU, dropout, projecting
    linear layer.

    Args:
        dim: Input and output dimension
        mult: Multiplier for hidden dimension. Defaults to 4
        dropout: Dropout probability. Defaults to 0.2
    """

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.2):
        super().__init__()
        hidden_dim = dim * mult
        # The first projection is twice as wide as the hidden dimension because
        # GEGLU halves the last axis (value half and gate half).
        layers = [
            nn.Linear(dim, hidden_dim * 2),
            _GEGLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)