"""Calcium POYO+ model (source of ``torch_brain.models.calcium_poyo_plus``)."""

from typing import Dict, List, Optional, Tuple
import logging

import numpy as np
from torch import cat, tensor, int64
from einops import rearrange, repeat
import torch.nn as nn
from torchtyping import TensorType
from temporaldata import Data

from torch_brain.data import chain, pad8, track_mask8
from torch_brain.nn import (
    Embedding,
    FeedForward,
    InfiniteVocabEmbedding,
    MultitaskReadout,
    RotaryCrossAttention,
    RotarySelfAttention,
    RotaryTimeEmbedding,
    prepare_for_multitask_readout,
)
from torch_brain.registry import ModalitySpec, MODALITY_REGISTRY

from torch_brain.utils import (
    create_linspace_latent_tokens,
    create_start_end_unit_tokens,
)


class CalciumPOYOPlus(nn.Module):
    """Calcium POYO+ model from `Azabou et al. 2025, Multi-session, multi-task neural
    decoding from distinct cell-types and brain regions
    <https://openreview.net/forum?id=IuU0wcO0mo>`_.

    Calcium POYO+ is a transformer-based model for neural decoding from calcium
    imaging recordings. It extends the POYO+ architecture with a calcium value map:
    each input token is the concatenation of a linear embedding of a scalar calcium
    value and a learned embedding of the unit that produced it.

    Args:
        sequence_length: Length of the context window; latent tokens are laid out
            between 0 and ``sequence_length`` in :meth:`tokenize`. Must be a
            positive ``float``.
        readout_specs: Mapping from task name to :class:`ModalitySpec`, passed to
            :class:`MultitaskReadout` (one readout layer per task).
        latent_step: Spacing between consecutive groups of latent tokens. Must be a
            positive ``float``; a warning is logged if ``sequence_length`` is not a
            multiple of it.
        num_latents_per_step: Number of latent tokens per latent step (size of the
            latent embedding table).
        dim: Model (token) dimension. The calcium value map and the unit embedding
            each produce ``dim // 2`` features that are concatenated, so ``dim`` is
            expected to be even.
        depth: Number of self-attention processing layers.
        dim_head: Dimension of each attention head (also the rotary ``head_dim``).
        cross_heads: Number of heads in the encoder and decoder cross-attention.
        self_heads: Number of heads in the processing self-attention layers.
        ffn_dropout: Dropout passed to every :class:`FeedForward` block.
        lin_dropout: Dropout applied to the residual updates of the processing
            layers.
        atn_dropout: Dropout passed to every attention layer.
        emb_init_scale: Init scale for the embeddings and the value-map weights.
        t_min: Passed as ``t_min`` to :class:`RotaryTimeEmbedding`.
        t_max: Passed as ``t_max`` to :class:`RotaryTimeEmbedding`.
    """

    def __init__(
        self,
        *,
        sequence_length: float,
        readout_specs: Dict[str, ModalitySpec] = MODALITY_REGISTRY,
        latent_step: float,
        num_latents_per_step: int = 64,
        dim: int = 512,
        depth: int = 2,
        dim_head: int = 64,
        cross_heads: int = 1,
        self_heads: int = 8,
        ffn_dropout: float = 0.2,
        lin_dropout: float = 0.4,
        atn_dropout: float = 0.0,
        emb_init_scale: float = 0.02,
        t_min: float = 1e-4,
        t_max: float = 2.0627,
    ):
        super().__init__()

        self._validate_params(sequence_length, latent_step)

        self.latent_step = latent_step
        self.num_latents_per_step = num_latents_per_step
        self.sequence_length = sequence_length
        self.readout_specs = readout_specs

        # input value map: one scalar calcium value -> dim // 2 features
        self.input_value_map = nn.Linear(1, dim // 2)
        nn.init.trunc_normal_(self.input_value_map.weight, 0, emb_init_scale)
        nn.init.zeros_(self.input_value_map.bias)
        # ^ initialize weights for faster convergence

        # embeddings
        # the unit embedding fills the other dim // 2 features of each input token
        self.unit_emb = InfiniteVocabEmbedding(dim // 2, init_scale=emb_init_scale)
        self.session_emb = InfiniteVocabEmbedding(dim, init_scale=emb_init_scale)
        # NOTE(review): token_type_emb is not used in forward(); presumably kept so
        # checkpoint state_dict keys stay compatible — confirm before removing.
        self.token_type_emb = Embedding(4, dim, init_scale=emb_init_scale)
        # NOTE(review): one more entry than there are readouts — presumably an index
        # is reserved; confirm against prepare_for_multitask_readout.
        self.task_emb = Embedding(
            len(readout_specs) + 1, dim, init_scale=emb_init_scale
        )
        self.latent_emb = Embedding(
            num_latents_per_step, dim, init_scale=emb_init_scale
        )
        self.rotary_emb = RotaryTimeEmbedding(
            head_dim=dim_head,
            rotate_dim=dim_head // 2,
            t_min=t_min,
            t_max=t_max,
        )

        # applied only to the residual updates of the processing layers
        self.dropout = nn.Dropout(p=lin_dropout)

        # encoder layer
        self.enc_atn = RotaryCrossAttention(
            dim=dim,
            heads=cross_heads,
            dropout=atn_dropout,
            dim_head=dim_head,
            rotate_value=True,
        )
        self.enc_ffn = nn.Sequential(
            nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout)
        )

        # process layers: each entry is Sequential(self-attention, norm + ffn); the
        # two children are unpacked and applied separately in forward()
        self.proc_layers = nn.ModuleList([])
        for _ in range(depth):
            self.proc_layers.append(
                nn.Sequential(
                    RotarySelfAttention(
                        dim=dim,
                        heads=self_heads,
                        dropout=atn_dropout,
                        dim_head=dim_head,
                        rotate_value=True,
                    ),
                    nn.Sequential(
                        nn.LayerNorm(dim),
                        FeedForward(dim=dim, dropout=ffn_dropout),
                    ),
                )
            )

        # decoder layer
        self.dec_atn = RotaryCrossAttention(
            dim=dim,
            heads=cross_heads,
            dropout=atn_dropout,
            dim_head=dim_head,
            rotate_value=False,
        )
        self.dec_ffn = nn.Sequential(
            nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout)
        )

        # Output projections + loss
        self.readout = MultitaskReadout(
            dim=dim,
            readout_specs=readout_specs,
        )

        self.dim = dim

    def forward(
        self,
        *,
        # input sequence
        input_unit_index: TensorType["batch", "n_in", int],
        input_timestamps: TensorType["batch", "n_in", float],
        input_values: TensorType["batch", "n_in", 1, float],
        input_mask: Optional[TensorType["batch", "n_in", bool]] = None,
        # latent sequence
        latent_index: TensorType["batch", "n_latent", int],
        latent_timestamps: TensorType["batch", "n_latent", float],
        # output sequence
        output_session_index: TensorType["batch", "n_out", int],
        output_timestamps: TensorType["batch", "n_out", float],
        output_decoder_index: TensorType["batch", "n_out", int],
        unpack_output: bool = False,
    ) -> Tuple[List[Dict[str, TensorType["*nqueries", "*nchannelsout"]]]]:
        """Forward pass of the Calcium POYO+ model.

        Input tokens are cross-attended into the latent tokens (encoder), the
        latents go through ``depth`` self-attention layers (processor), and output
        queries built from session + task embeddings cross-attend into the latents
        (decoder). :attr:`readout` then maps each decoded query to its task's
        prediction.

        Args:
            input_unit_index: Global unit-vocabulary index of each input token.
            input_timestamps: Timestamp of each input token.
            input_values: Calcium value of each input token, with a trailing
                singleton feature axis (required by the ``nn.Linear(1, ...)`` value
                map; :meth:`tokenize` emits it).
            input_mask: Mask over input tokens, forwarded to the encoder
                cross-attention (:meth:`tokenize` builds it with ``track_mask8``).
            latent_index: Index into the latent embedding table for each latent
                token.
            latent_timestamps: Timestamp of each latent token.
            output_session_index: Session-vocabulary index of each output query.
            output_timestamps: Timestamp of each output query.
            output_decoder_index: Task / readout index of each output query; used
                both for the task embedding and to route queries in the readout.
            unpack_output: Passed through unchanged to :class:`MultitaskReadout`.

        Returns:
            Whatever :attr:`readout` returns for the decoded output queries. No loss
            is computed in this method.

        Raises:
            ValueError: If the unit or session vocabulary has not been initialized.
        """
        if self.unit_emb.is_lazy():
            raise ValueError(
                "Unit vocabulary has not been initialized, please use "
                "`model.unit_emb.initialize_vocab(unit_ids)`"
            )
        if self.session_emb.is_lazy():
            raise ValueError(
                "Session vocabulary has not been initialized, please use "
                "`model.session_emb.initialize_vocab(session_ids)`"
            )

        # input: each token = [value embedding (dim // 2) | unit embedding (dim // 2)]
        inputs = cat(
            (self.input_value_map(input_values), self.unit_emb(input_unit_index)),
            dim=-1,
        )
        input_timestamp_emb = self.rotary_emb(input_timestamps)

        # latents
        latents = self.latent_emb(latent_index)
        latent_timestamp_emb = self.rotary_emb(latent_timestamps)

        # outputs: each query identifies the session and the task to decode
        output_queries = self.session_emb(output_session_index) + self.task_emb(
            output_decoder_index
        )
        output_timestamp_emb = self.rotary_emb(output_timestamps)

        # encode: latents cross-attend into the (optionally masked) input tokens
        latents = latents + self.enc_atn(
            latents,
            inputs,
            latent_timestamp_emb,
            input_timestamp_emb,
            input_mask,
        )
        latents = latents + self.enc_ffn(latents)

        # process
        for self_attn, self_ff in self.proc_layers:
            latents = latents + self.dropout(self_attn(latents, latent_timestamp_emb))
            latents = latents + self.dropout(self_ff(latents))

        # decode: output queries cross-attend into the latents
        output_queries = output_queries + self.dec_atn(
            output_queries,
            latents,
            output_timestamp_emb,
            latent_timestamp_emb,
        )
        output_latents = output_queries + self.dec_ffn(output_queries)

        # multitask readout layer, each task has a separate linear readout layer
        output = self.readout(
            output_embs=output_latents,
            output_readout_index=output_decoder_index,
            unpack_output=unpack_output,
        )

        return output

    def tokenize(self, data: Data) -> Dict:
        r"""Tokenizer used to tokenize Data for the Calcium POYO+ model.

        This tokenizer can be called as a transform. If you are applying multiple
        transforms, make sure to apply this one last.

        This code runs on CPU. Do not access GPU tensors inside this function.

        Every (frame, unit) entry of ``data.calcium_traces.df_over_f`` becomes one
        input token. Tokens are ordered frame-major: all units of the first frame,
        then all units of the second frame, and so on.

        Returns:
            A dict with ``"model_inputs"`` (keyword arguments for :meth:`forward`),
            ``"target_values"``, ``"target_weights"``, ``"session_id"``,
            ``"absolute_start"`` and ``"eval_mask"``.

        Raises:
            ValueError: If the number of columns of ``df_over_f`` does not match
                the number of entries in ``data.units.id``.
        """
        # context window
        start, end = 0, self.sequence_length

        ### prepare input
        calcium_traces = data.calcium_traces
        unit_ids = data.units.id
        df_over_f = calcium_traces.df_over_f
        T, N = df_over_f.shape
        if N != len(unit_ids):
            # fail early with a clear message instead of an opaque einops error
            raise ValueError(
                f"calcium_traces.df_over_f has {N} unit columns but data.units.id "
                f"has {len(unit_ids)} entries"
            )

        # flatten (T, N) -> (T * N) tokens in frame-major order; each frame's
        # timestamp is repeated once per unit
        input_timestamps = repeat(calcium_traces.timestamps, "T -> (T N)", T=T, N=N)

        ### prepare calcium values
        # trailing singleton axis feeds the nn.Linear(1, dim // 2) value map
        input_values = rearrange(df_over_f, "T N -> (T N) 1", T=T, N=N)

        # input unit indices: map local unit ids to global vocabulary indices, then
        # tile them once per frame so they line up with the frame-major flattening
        local_to_global_map = np.array(self.unit_emb.tokenizer(unit_ids))
        input_unit_index = tensor(local_to_global_map, dtype=int64)
        input_unit_index = repeat(input_unit_index, "N -> (T N)", T=T, N=N)

        ### prepare latents
        latent_index, latent_timestamps = create_linspace_latent_tokens(
            start,
            end,
            step=self.latent_step,
            num_latents_per_step=self.num_latents_per_step,
        )

        ### prepare outputs
        session_index = self.session_emb.tokenizer(data.session.id)

        (
            output_timestamps,
            output_values,
            output_task_index,
            output_weights,
            output_eval_mask,
        ) = prepare_for_multitask_readout(
            data,
            self.readout_specs,
        )

        # one session index per output query
        session_index = np.repeat(session_index, len(output_timestamps))

        data_dict = {
            "model_inputs": {
                # input sequence
                "input_unit_index": pad8(input_unit_index),
                "input_timestamps": pad8(input_timestamps),
                "input_values": pad8(input_values),
                "input_mask": track_mask8(input_unit_index),
                # latent sequence
                "latent_index": latent_index,
                "latent_timestamps": latent_timestamps,
                # output sequence
                "output_session_index": pad8(session_index),
                "output_timestamps": pad8(output_timestamps),
                "output_decoder_index": pad8(output_task_index),
            },
            # ground truth targets
            "target_values": chain(output_values, allow_missing_keys=True),
            "target_weights": chain(output_weights, allow_missing_keys=True),
            # extra fields for evaluation
            "session_id": data.session.id,
            "absolute_start": data.absolute_start,
            "eval_mask": chain(output_eval_mask, allow_missing_keys=True),
        }

        return data_dict

    def _validate_params(self, sequence_length, latent_step):
        r"""Validate the context-window parameters.

        Ensures ``sequence_length`` and ``latent_step`` are ``float`` values greater
        than zero, and logs a warning when ``sequence_length`` is not (within a
        small floating-point tolerance) a multiple of ``latent_step``.

        Raises:
            ValueError: If either value is not a ``float`` or is not greater than 0
                (``not x > 0`` also rejects NaN).
        """
        if not isinstance(sequence_length, float):
            raise ValueError("sequence_length must be a float")
        if not sequence_length > 0:
            raise ValueError("sequence_length must be greater than 0")
        self.sequence_length = sequence_length

        if not isinstance(latent_step, float):
            raise ValueError("latent_step must be a float")
        if not latent_step > 0:
            raise ValueError("latent_step must be greater than 0")
        self.latent_step = latent_step

        # check if sequence_length is a multiple of latent_step. Float modulo can
        # land just *below* latent_step for exact multiples (e.g. 1.0 % 0.1 ==
        # 0.09999999999999995), so measure the distance to the nearest multiple on
        # either side instead of only the distance to 0.
        remainder = sequence_length % latent_step
        if min(remainder, latent_step - remainder) > 1e-10:
            logging.warning(
                f"sequence_length ({sequence_length}) is not a multiple of latent_step "
                f"({latent_step}). This is a simple warning, and this behavior is allowed."
            )