from typing import Dict, List, Optional, Tuple
import logging
import numpy as np
from torch import cat, tensor, int64
from einops import rearrange, repeat
import torch.nn as nn
from torchtyping import TensorType
from temporaldata import Data
from torch_brain.data import chain, pad8, track_mask8
from torch_brain.nn import (
Embedding,
FeedForward,
InfiniteVocabEmbedding,
MultitaskReadout,
RotaryCrossAttention,
RotarySelfAttention,
RotaryTimeEmbedding,
prepare_for_multitask_readout,
)
from torch_brain.registry import ModalitySpec, MODALITY_REGISTRY
from torch_brain.utils import (
create_linspace_latent_tokens,
create_start_end_unit_tokens,
)
# Calcium POYO+ model definition.
class CalciumPOYOPlus(nn.Module):
    """Calcium POYO+ model from `Azabou et al. 2025, Multi-session, multi-task neural decoding from distinct cell-types and brain regions <https://openreview.net/forum?id=IuU0wcO0mo>`_.

    Calcium POYO+ is a transformer-based model for neural decoding from calcium
    imaging recordings. It extends the POYO+ architecture with a calcium value map.

    Every input token is the concatenation of a linear projection of a single
    calcium value (``dim // 2`` features) and the embedding of the unit it
    belongs to (``dim // 2`` features). The tokens are compressed into a latent
    sequence by one cross-attention layer, refined by ``depth`` self-attention
    layers, and finally queried by the output tokens through a second
    cross-attention layer followed by a multitask readout.

    Args:
        sequence_length: Length of the context window. Must be a ``float``
            greater than zero. It is the upper bound of the latent-token grid
            built in :meth:`tokenize`.
        readout_specs: Mapping from readout name to its ``ModalitySpec``. It is
            handed to ``MultitaskReadout`` and to
            ``prepare_for_multitask_readout``.
        latent_step: Step between consecutive groups of latent tokens. Must be
            a ``float`` greater than zero.
        num_latents_per_step: Number of latent tokens per step; also the size
            of the latent embedding table.
        dim: Width of the token embeddings. The input token is built from two
            ``dim // 2`` halves, so an even value is presumably expected.
        depth: Number of self-attention processing layers.
        dim_head: Dimension of each attention head.
        cross_heads: Number of heads in the encoder and decoder
            cross-attention layers.
        self_heads: Number of heads in the self-attention layers.
        ffn_dropout: Dropout passed to every ``FeedForward`` block.
        lin_dropout: Dropout applied to the output of each processing
            sub-layer before it is added back to the residual stream.
        atn_dropout: Dropout passed to every attention layer.
        emb_init_scale: Initialisation scale of the embedding tables and of
            the calcium value map weights.
        t_min: Forwarded to ``RotaryTimeEmbedding``.
        t_max: Forwarded to ``RotaryTimeEmbedding``.

    Raises:
        ValueError: If ``sequence_length`` or ``latent_step`` is not a float
            or is not greater than zero.
    """

    def __init__(
        self,
        *,
        sequence_length: float,
        readout_specs: Dict[str, ModalitySpec] = MODALITY_REGISTRY,
        latent_step: float,
        num_latents_per_step: int = 64,
        dim: int = 512,
        depth: int = 2,
        dim_head: int = 64,
        cross_heads: int = 1,
        self_heads: int = 8,
        ffn_dropout: float = 0.2,
        lin_dropout: float = 0.4,
        atn_dropout: float = 0.0,
        emb_init_scale: float = 0.02,
        t_min: float = 1e-4,
        t_max: float = 2.0627,
    ):
        super().__init__()

        self._validate_params(sequence_length, latent_step)

        self.latent_step = latent_step
        self.num_latents_per_step = num_latents_per_step
        self.sequence_length = sequence_length
        self.readout_specs = readout_specs

        # NOTE: the sub-modules below are created in a fixed order on purpose.
        # Changing the order would change both the state-dict layout and the
        # random numbers consumed by each initialiser under a fixed seed.

        # input value map: projects one scalar calcium value to dim // 2
        self.input_value_map = nn.Linear(1, dim // 2)
        # initialize weights for faster convergence
        nn.init.trunc_normal_(self.input_value_map.weight, 0, emb_init_scale)
        nn.init.zeros_(self.input_value_map.bias)

        # embeddings
        self.unit_emb = InfiniteVocabEmbedding(dim // 2, init_scale=emb_init_scale)
        self.session_emb = InfiniteVocabEmbedding(dim, init_scale=emb_init_scale)
        # Not referenced by forward() or tokenize() in this class; kept so the
        # set of registered parameters stays the same.
        self.token_type_emb = Embedding(4, dim, init_scale=emb_init_scale)
        self.task_emb = Embedding(
            len(readout_specs) + 1, dim, init_scale=emb_init_scale
        )
        self.latent_emb = Embedding(
            num_latents_per_step, dim, init_scale=emb_init_scale
        )
        self.rotary_emb = RotaryTimeEmbedding(
            head_dim=dim_head,
            rotate_dim=dim_head // 2,
            t_min=t_min,
            t_max=t_max,
        )

        self.dropout = nn.Dropout(p=lin_dropout)

        # encoder layer
        self.enc_atn = RotaryCrossAttention(
            dim=dim,
            heads=cross_heads,
            dropout=atn_dropout,
            dim_head=dim_head,
            rotate_value=True,
        )
        self.enc_ffn = nn.Sequential(
            nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout)
        )

        # process layers: each entry is an (attention, feed-forward) pair that
        # forward() unpacks by iterating over the nn.Sequential.
        self.proc_layers = nn.ModuleList([])
        for _ in range(depth):
            self.proc_layers.append(
                nn.Sequential(
                    RotarySelfAttention(
                        dim=dim,
                        heads=self_heads,
                        dropout=atn_dropout,
                        dim_head=dim_head,
                        rotate_value=True,
                    ),
                    nn.Sequential(
                        nn.LayerNorm(dim),
                        FeedForward(dim=dim, dropout=ffn_dropout),
                    ),
                )
            )

        # decoder layer
        self.dec_atn = RotaryCrossAttention(
            dim=dim,
            heads=cross_heads,
            dropout=atn_dropout,
            dim_head=dim_head,
            rotate_value=False,
        )
        self.dec_ffn = nn.Sequential(
            nn.LayerNorm(dim), FeedForward(dim=dim, dropout=ffn_dropout)
        )

        # Output projections
        self.readout = MultitaskReadout(
            dim=dim,
            readout_specs=readout_specs,
        )

        self.dim = dim

    def forward(
        self,
        *,
        # input sequence
        input_unit_index: TensorType["batch", "n_in", int],
        input_timestamps: TensorType["batch", "n_in", float],
        input_values: TensorType["batch", "n_in", float],
        input_mask: Optional[TensorType["batch", "n_in", bool]] = None,
        # latent sequence
        latent_index: TensorType["batch", "n_latent", int],
        latent_timestamps: TensorType["batch", "n_latent", float],
        # output sequence
        output_session_index: TensorType["batch", "n_out", int],
        output_timestamps: TensorType["batch", "n_out", float],
        output_decoder_index: TensorType["batch", "n_out", int],
        unpack_output: bool = False,
    ) -> Tuple[List[Dict[str, TensorType["*nqueries", "*nchannelsout"]]]]:
        """Forward pass of the Calcium POYO+ model.

        The model processes the input calcium sequence through its
        encoder-processor-decoder architecture to generate task-specific
        predictions. No loss is computed here.

        Args:
            input_unit_index: Indices of the units the input tokens belong to.
            input_timestamps: Timestamps of the input tokens.
            input_values: Calcium values of the input tokens. They are fed to a
                ``nn.Linear(1, dim // 2)``, so the last dimension must have
                size 1 (as produced by :meth:`tokenize`).
            input_mask: Optional mask for the input sequence, passed to the
                encoder cross-attention.
            latent_index: Indices of the latent tokens.
            latent_timestamps: Timestamps of the latent tokens.
            output_session_index: Index of the recording session for each
                output query.
            output_timestamps: Timestamps of the output queries.
            output_decoder_index: Indices selecting the task embedding and the
                readout head for each output query.
            unpack_output: Forwarded unchanged to the multitask readout.

        Returns:
            Whatever ``self.readout`` (a ``MultitaskReadout``) returns for the
            decoded output embeddings, i.e. the per-task predictions.

        Raises:
            ValueError: If the unit or session vocabulary has not been
                initialized yet.
        """
        if self.unit_emb.is_lazy():
            raise ValueError(
                "Unit vocabulary has not been initialized, please use "
                "`model.unit_emb.initialize_vocab(unit_ids)`"
            )

        if self.session_emb.is_lazy():
            raise ValueError(
                "Session vocabulary has not been initialized, please use "
                "`model.session_emb.initialize_vocab(session_ids)`"
            )

        # input: [projected calcium value | unit embedding] along the feature axis
        inputs = cat(
            (self.input_value_map(input_values), self.unit_emb(input_unit_index)),
            dim=-1,
        )
        input_timestamp_emb = self.rotary_emb(input_timestamps)

        # latents
        latents = self.latent_emb(latent_index)
        latent_timestamp_emb = self.rotary_emb(latent_timestamps)

        # outputs
        output_queries = self.session_emb(output_session_index) + self.task_emb(
            output_decoder_index
        )
        output_timestamp_emb = self.rotary_emb(output_timestamps)

        # encode
        latents = latents + self.enc_atn(
            latents,
            inputs,
            latent_timestamp_emb,
            input_timestamp_emb,
            input_mask,
        )
        latents = latents + self.enc_ffn(latents)

        # process
        for self_attn, self_ff in self.proc_layers:
            latents = latents + self.dropout(self_attn(latents, latent_timestamp_emb))
            latents = latents + self.dropout(self_ff(latents))

        # decode
        output_queries = output_queries + self.dec_atn(
            output_queries,
            latents,
            output_timestamp_emb,
            latent_timestamp_emb,
        )
        output_latents = output_queries + self.dec_ffn(output_queries)

        # multitask readout layer, each task has a separate linear readout layer
        output = self.readout(
            output_embs=output_latents,
            output_readout_index=output_decoder_index,
            unpack_output=unpack_output,
        )

        return output

    def tokenize(self, data: Data) -> Dict:
        r"""Tokenizer used to tokenize Data for the Calcium POYO+ model.

        This tokenizer can be called as a transform. If you are applying multiple
        transforms, make sure to apply this one last.

        This code runs on CPU. Do not access GPU tensors inside this function.

        Args:
            data: Sample to tokenize. The fields read here are
                ``calcium_traces`` (``df_over_f`` and ``timestamps``),
                ``units.id``, ``session.id`` and ``absolute_start``; ``data`` is
                also passed to ``prepare_for_multitask_readout``.

        Returns:
            A dictionary with the keyword arguments of :meth:`forward` under
            ``"model_inputs"``, the targets under ``"target_values"`` and
            ``"target_weights"``, and the evaluation metadata ``"session_id"``,
            ``"absolute_start"`` and ``"eval_mask"``.
        """
        # context window
        start, end = 0, self.sequence_length

        ### prepare input
        calcium_traces = data.calcium_traces
        unit_ids = data.units.id

        # One token per (time step, unit) pair, flattened time-major: all units
        # of the first time step come first, then all units of the second, ...
        T, N = calcium_traces.df_over_f.shape

        # every timestamp is repeated once per unit
        input_timestamps = repeat(calcium_traces.timestamps, "T -> (T N)", T=T, N=N)

        ### prepare calcium values (trailing axis of size 1 for the value map)
        input_values = rearrange(calcium_traces.df_over_f, "T N -> (T N) 1", T=T, N=N)

        # input unit indices: global vocabulary index of every unit, tiled once
        # per time step so that it lines up with the flattened values
        local_to_global_map = np.array(self.unit_emb.tokenizer(unit_ids))
        input_unit_index = tensor(
            local_to_global_map[: len(unit_ids)], dtype=int64
        )
        input_unit_index = repeat(input_unit_index, "N -> (T N)", T=T, N=N)

        ### prepare latents
        latent_index, latent_timestamps = create_linspace_latent_tokens(
            start,
            end,
            step=self.latent_step,
            num_latents_per_step=self.num_latents_per_step,
        )

        ### prepare outputs
        session_index = self.session_emb.tokenizer(data.session.id)

        (
            output_timestamps,
            output_values,
            output_task_index,
            output_weights,
            output_eval_mask,
        ) = prepare_for_multitask_readout(
            data,
            self.readout_specs,
        )

        # one session index per output query
        session_index = np.repeat(session_index, len(output_timestamps))

        data_dict = {
            "model_inputs": {
                # input sequence
                "input_unit_index": pad8(input_unit_index),
                "input_timestamps": pad8(input_timestamps),
                "input_values": pad8(input_values),
                "input_mask": track_mask8(input_unit_index),
                # latent sequence
                "latent_index": latent_index,
                "latent_timestamps": latent_timestamps,
                # output sequence
                "output_session_index": pad8(session_index),
                "output_timestamps": pad8(output_timestamps),
                "output_decoder_index": pad8(output_task_index),
            },
            # ground truth targets
            "target_values": chain(output_values, allow_missing_keys=True),
            "target_weights": chain(output_weights, allow_missing_keys=True),
            # extra fields for evaluation
            "session_id": data.session.id,
            "absolute_start": data.absolute_start,
            "eval_mask": chain(output_eval_mask, allow_missing_keys=True),
        }

        return data_dict

    def _validate_params(self, sequence_length, latent_step):
        r"""Ensure: sequence_length, and latent_step are floating point numbers greater
        than zero. Warn when sequence_length is not a multiple of latent_step.

        Both values are also stored on the instance once they pass validation.

        Args:
            sequence_length: Length of the context window.
            latent_step: Step between consecutive groups of latent tokens.

        Raises:
            ValueError: If either argument is not a ``float`` or is not greater
                than zero.
        """
        if not isinstance(sequence_length, float):
            raise ValueError("sequence_length must be a float")
        if not sequence_length > 0:
            raise ValueError("sequence_length must be greater than 0")
        self.sequence_length = sequence_length

        if not isinstance(latent_step, float):
            raise ValueError("latent_step must be a float")
        if not latent_step > 0:
            raise ValueError("latent_step must be greater than 0")
        self.latent_step = latent_step

        # Check if sequence_length is a multiple of latent_step. With floats the
        # remainder of an exact multiple can land just below latent_step rather
        # than at 0 (e.g. 1.0 % 0.1 == 0.09999999999999995), so measure the
        # distance to the nearest multiple instead of the raw remainder.
        remainder = sequence_length % latent_step
        distance_to_multiple = min(remainder, latent_step - remainder)
        if distance_to_multiple > 1e-10:
            logging.warning(
                f"sequence_length ({sequence_length}) is not a multiple of latent_step "
                f"({latent_step}). This is a simple warning, and this behavior is allowed."
            )