Source code for torch_brain.models.poyo

import inspect
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import logging

from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtyping import TensorType
from temporaldata import Data

from torch_brain.dataset import Dataset
from torch_brain.data import pad8, track_mask8
from torch_brain.nn import (
    Embedding,
    InfiniteVocabEmbedding,
    RotaryCrossAttention,
    RotarySelfAttention,
    RotaryTimeEmbedding,
)

from torch_brain.utils import (
    create_linspace_latent_tokens,
    create_start_end_unit_tokens,
)


[docs] class POYO(nn.Module): """POYO model from `Azabou et al. 2023, A Unified, Scalable Framework for Neural Population Decoding <https://arxiv.org/abs/2310.16046>`_. POYO is a transformer-based model for neural decoding from electrophysiological recordings. 1. Input tokens are constructed by combining unit embeddings, token type embeddings, and time embeddings for each spike in the sequence. 2. The input sequence is compressed using cross-attention, where learnable latent tokens (each with an associated timestamp) attend to the input tokens. 3. The compressed latent token representations undergo further refinement through multiple self-attention processing layers. 4. Query tokens are constructed for the desired outputs by combining session embeddings, and output timestamps. 5. These query tokens attend to the processed latent representations through cross-attention, producing outputs in the model's dimensional space (dim). 6. Finally, a task-specific linear layer maps the outputs from the model dimension to the appropriate output dimension. Args: sequence_length: Maximum duration of the input spike sequence (in seconds) latent_step: Timestep of the latent grid (in seconds) num_latents_per_step: Number of unique latent tokens (repeated at every latent step) dim: Hidden dimension of the model dim_out: Output dimension of the model depth: Number of processing layers (self-attentions in the latent space) dim_head: Dimension of each attention head cross_heads: Number of attention heads used in a cross-attention layer self_heads: Number of attention heads used in a self-attention layer ffn_dropout: Dropout rate for feed-forward networks lin_dropout: Dropout rate for linear layers atn_dropout: Dropout rate for attention emb_init_scale: Scale for embedding initialization t_min: Minimum timestamp resolution for rotary embeddings t_max: Maximum timestamp resolution for rotary embeddings """ def __init__( self, *, sequence_length: float, latent_step: float, num_latents_per_step: int = 64, dim_out: int, dim: int = 512, depth: int = 2, dim_head: int = 64, cross_heads: int = 1, self_heads: int = 8, ffn_dropout: float = 0.2, lin_dropout: float = 0.4, atn_dropout: float = 0.0, emb_init_scale: float = 0.02, t_min: float = 1e-4, t_max: float = 2.0627, ): super().__init__() self._validate_params(sequence_length, latent_step) self.sequence_length = sequence_length self.latent_step = latent_step self.num_latents_per_step = num_latents_per_step # embeddings self.unit_emb = InfiniteVocabEmbedding(dim, init_scale=emb_init_scale) self.session_emb = InfiniteVocabEmbedding(dim, init_scale=emb_init_scale) self.token_type_emb = Embedding(4, dim, init_scale=emb_init_scale) self.latent_emb = Embedding( num_latents_per_step, dim, init_scale=emb_init_scale ) self.rotary_emb = RotaryTimeEmbedding( head_dim=dim_head, rotate_dim=dim_head // 2, t_min=t_min, t_max=t_max, ) self.dropout = nn.Dropout(p=lin_dropout) # encoder layer self.enc_atn = RotaryCrossAttention( dim=dim, heads=cross_heads, dropout=atn_dropout, dim_head=dim_head, rotate_value=True, ) self.enc_ffn = nn.Sequential( nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout) ) # process layers self.proc_layers = nn.ModuleList([]) for i in range(depth): self.proc_layers.append( nn.Sequential( RotarySelfAttention( dim=dim, heads=self_heads, dropout=atn_dropout, dim_head=dim_head, rotate_value=True, ), nn.Sequential( nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout), ), ) ) # decoder layer self.dec_atn = RotaryCrossAttention( dim=dim, heads=cross_heads, dropout=atn_dropout, dim_head=dim_head, rotate_value=False, ) self.dec_ffn = nn.Sequential( nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout) ) # Output projections + loss self.readout = nn.Linear(dim, dim_out) self.dim = dim
[docs] def forward( self, *, # input sequence input_unit_index: TensorType["batch", "n_in", int], input_timestamps: TensorType["batch", "n_in", float], input_token_type: TensorType["batch", "n_in", int], input_mask: Optional[TensorType["batch", "n_in", bool]] = None, # latent sequence latent_index: TensorType["batch", "n_latent", int], latent_timestamps: TensorType["batch", "n_latent", float], # Metadata for queries session_index: TensorType["batch", int], # output sequence output_timestamps: TensorType["batch", "n_out", float], output_mask: Optional[TensorType["batch", "n_out", bool]] = None, unpack_output: bool = False, ) -> Union[ TensorType["batch", "n_out", "dim_out", float], List[TensorType[..., "dim_out", float]], ]: """Forward pass of the POYO model. The model processes input spike sequences through its encoder-processor-decoder architecture to generate task-specific predictions. Args: input_unit_index: Indices of input units input_timestamps: Timestamps of input spikes input_token_type: Type of input tokens input_mask: Mask for input sequence latent_index: Indices for latent tokens latent_timestamps: Timestamps for latent tokens session_index: Index of the recording session output_timestamps: Timestamps for output predictions output_mask: A mask of the same size as output_timestamps. True implies that particular timestamp is a valid query for POYO. This is required iff `unpack_output` is set to True. unpack_output: If False, this function will return a padded tensor of shape (batch size, num of max output queries in batch, `dim_out`). In this case you have to use `output_mask` externally to only look at valid outputs. If True, this will return a list of Tensors: the length of the list is equal to batch size, the shape of i^th Tensor is (num of valid output queries for i^th sample, `d_out`). Returns: A :class:`torch.Tensor` of shape `(batch, n_out, dim_out)` containing the predicted outputs corresponding to `output_timestamps`. """ if self.unit_emb.is_lazy(): raise ValueError( "Unit vocabulary has not been initialized, please use " "`model.unit_emb.initialize_vocab(unit_ids)`" ) if self.session_emb.is_lazy(): raise ValueError( "Session vocabulary has not been initialized, please use " "`model.session_emb.initialize_vocab(session_ids)`" ) # input inputs = self.unit_emb(input_unit_index) + self.token_type_emb(input_token_type) input_timestamp_emb = self.rotary_emb(input_timestamps) # latents latents = self.latent_emb(latent_index) latent_timestamp_emb = self.rotary_emb(latent_timestamps) # outputs output_queries = self.session_emb(session_index).unsqueeze(1) output_queries = output_queries.expand(-1, output_timestamps.size(1), -1) output_timestamp_emb = self.rotary_emb(output_timestamps) # encode latents = latents + self.enc_atn( latents, inputs, latent_timestamp_emb, input_timestamp_emb, input_mask, ) latents = latents + self.enc_ffn(latents) # process for self_attn, self_ff in self.proc_layers: latents = latents + self.dropout(self_attn(latents, latent_timestamp_emb)) latents = latents + self.dropout(self_ff(latents)) # decode output_queries = output_queries + self.dec_atn( output_queries, latents, output_timestamp_emb, latent_timestamp_emb, ) output_latents = output_queries + self.dec_ffn(output_queries) output = self.readout(output_latents) if unpack_output: output = [output[b][output_mask[b]] for b in range(output.size(0))] return output
[docs] def tokenize(self, data: Data) -> Dict: r"""Tokenizer used to tokenize Data for the POYO model. This tokenizer can be called as a transform. If you are applying multiple transforms, make sure to apply this one last. This code runs on CPU. Do not access GPU tensors inside this function. """ # context window start, end = 0, self.sequence_length ### prepare input unit_ids = data.units.id spike_unit_index = data.spikes.unit_index spike_timestamps = data.spikes.timestamps # create start and end tokens for each unit ( se_token_type_index, se_unit_index, se_timestamps, ) = create_start_end_unit_tokens(unit_ids, start, end) # append start and end tokens to the spike sequence spike_token_type_index = np.concatenate( [se_token_type_index, np.zeros_like(spike_unit_index)] ) spike_unit_index = np.concatenate([se_unit_index, spike_unit_index]) spike_timestamps = np.concatenate([se_timestamps, spike_timestamps]) # unit_index is relative to the recording, so we want it to map it to # the global unit index local_to_global_map = np.array(self.unit_emb.tokenizer(unit_ids)) spike_unit_index = local_to_global_map[spike_unit_index] ### prepare latents latent_index, latent_timestamps = create_linspace_latent_tokens( start, end, step=self.latent_step, num_latents_per_step=self.num_latents_per_step, ) # create session index for output session_index = self.session_emb.tokenizer(data.session.id) data_dict = dict( # input sequence (keys/values for the encoder) input_unit_index=pad8(spike_unit_index), input_timestamps=pad8(spike_timestamps), input_token_type=pad8(spike_token_type_index), input_mask=track_mask8(spike_unit_index), # latent sequence latent_index=latent_index, latent_timestamps=latent_timestamps, # metadata needed for decoder queries session_index=session_index, ) return data_dict
def _validate_params(self, sequence_length, latent_step): r"""Ensure: sequence_length, and latent_step are floating point numbers greater than zero. And sequence_length is a multiple of latent_step. """ if not isinstance(sequence_length, float): raise ValueError("sequence_length must be a float") if not sequence_length > 0: raise ValueError("sequence_length must be greater than 0") self.sequence_length = sequence_length if not isinstance(latent_step, float): raise ValueError("latent_step must be a float") if not latent_step > 0: raise ValueError("latent_step must be greater than 0") self.latent_step = latent_step # check if sequence_length is a multiple of latent_step if abs(sequence_length % latent_step) > 1e-10: logging.warning( f"sequence_length ({sequence_length}) is not a multiple of latent_step " f"({latent_step}). This is a simple warning, and this behavior is allowed." )
[docs] def init_vocabs(self, dataset: Dataset): """Initializes model's unit_emb and session_emb vocabularies. Args: dataset: A :class:`Dataset` with :class:`SpikingDatasetMixin`. """ if hasattr(dataset, "get_unit_ids") and inspect.ismethod(dataset.get_unit_ids): unit_ids = dataset.get_unit_ids() else: raise ValueError( "Could not call method ``get_unit_ids()`` on input ``dataset``." " Perhaps this is not a spiking dataset?" " Consider adding ``SpikingDatasetMixin`` to your Dataset class, which would" " provide this method." ) self.unit_emb.initialize_vocab(unit_ids) self.session_emb.initialize_vocab(dataset.recording_ids)
[docs] @classmethod def load_pretrained(cls, checkpoint_path: str | Path) -> "POYO": """Load a pretrained POYO model from a checkpoint file. Args: checkpoint_path: Path to the checkpoint file containing model weights and hyperparameters. Returns: An instance of the POYO model with weights loaded from the checkpoint. """ # For now, we are loading from the checkpoint generated using the official # Lightning trainer ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) hparams = ckpt["hyper_parameters"]["model"] hparams.pop("_target_", None) state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()} # Infer `dim_out` from shape of readout weights dim_out = state_dict["readout.weight"].size(0) model = cls(**hparams, dim_out=dim_out) model.load_state_dict(state_dict) return model
class _GEGLU(nn.Module): """Gated Gaussian Error Linear Unit (GEGLU) activation function, as introduced in the paper "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202). """ def forward(self, x): x, gates = x.chunk(2, dim=-1) return x * F.gelu(gates) class FeedForward(nn.Module): """A feed-forward network with GEGLU activation. Args: dim (int): Input and output dimension mult (int, optional): Multiplier for hidden dimension. Defaults to 4 dropout (float, optional): Dropout probability. Defaults to 0.2 """ def __init__(self, dim, mult=4, dropout=0.2): super().__init__() self.net = nn.Sequential( nn.Linear(dim, dim * mult * 2), _GEGLU(), nn.Dropout(p=dropout), nn.Linear(dim * mult, dim), ) def forward(self, x): return self.net(x)