torch_brain.models
Brain decoders
- class POYO(*, sequence_length, readout_spec, latent_step, num_latents_per_step=64, dim=512, depth=2, dim_head=64, cross_heads=1, self_heads=8, ffn_dropout=0.2, lin_dropout=0.4, atn_dropout=0.0, emb_init_scale=0.02, t_min=0.0001, t_max=4.0)[source]
Bases: Module
POYO model from Azabou et al. 2023, A Unified, Scalable Framework for Neural Population Decoding.
POYO is a transformer-based model for neural decoding from electrophysiological recordings. The model processes neural spike sequences through the following steps:
- Input tokens are constructed by combining unit embeddings, token type embeddings,
and time embeddings for each spike in the sequence.
- The input sequence is compressed using cross-attention, where learnable latent
tokens (each with an associated timestamp) attend to the input tokens.
- The compressed latent token representations undergo further refinement through
multiple self-attention processing layers.
- Query tokens are constructed for the desired outputs by combining session embeddings and output timestamps.
- These query tokens attend to the processed latent representations through
cross-attention, producing outputs in the model’s dimensional space (dim).
- Finally, a task-specific linear layer maps the outputs from the model dimension
to the appropriate output dimension.
- Parameters:
  - sequence_length (float) – Maximum duration of the input spike sequence (in seconds)
  - readout_spec (ModalitySpec) – A torch_brain.registry.ModalitySpec specifying readout properties
  - latent_step (float) – Timestep of the latent grid (in seconds)
  - num_latents_per_step (int) – Number of unique latent tokens (repeated at every latent step)
  - dim (int) – Hidden dimension of the model
  - depth (int) – Number of processing layers (self-attentions in the latent space)
  - dim_head (int) – Dimension of each attention head
  - cross_heads (int) – Number of attention heads used in a cross-attention layer
  - self_heads (int) – Number of attention heads used in a self-attention layer
  - ffn_dropout (float) – Dropout rate for feed-forward networks
  - lin_dropout (float) – Dropout rate for linear layers
  - atn_dropout (float) – Dropout rate for attention
  - emb_init_scale (float) – Scale for embedding initialization
  - t_min (float) – Minimum timestamp resolution for rotary embeddings
  - t_max (float) – Maximum timestamp resolution for rotary embeddings
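A minimal instantiation sketch is shown below. It assumes ModalitySpec (and a DataType enum) can be imported from torch_brain.registry; the field values mirror the entries visible in the default readout_specs of POYOPlus further down this page, and exact import paths may differ between versions.

```python
import torch.nn as nn

from torch_brain.models import POYO
from torch_brain.registry import ModalitySpec, DataType  # DataType location is an assumption

# Hypothetical readout: 2D cursor velocity regression. Field names mirror the
# ModalitySpec entries shown in the default readout_specs of POYOPlus below.
cursor_velocity = ModalitySpec(
    id=1,
    dim=2,
    type=DataType.CONTINUOUS,
    timestamp_key="cursor.timestamps",
    value_key="cursor.vel",
    loss_fn=nn.MSELoss(),
)

model = POYO(
    sequence_length=1.0,    # seconds of spiking context per sample
    readout_spec=cursor_velocity,
    latent_step=0.125,      # one set of latent tokens every 125 ms
    num_latents_per_step=64,
    dim=512,
)
```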
- forward(*, input_unit_index, input_timestamps, input_token_type, input_mask=None, latent_index, latent_timestamps, output_session_index, output_timestamps, output_mask=None, unpack_output=False)[source]
Forward pass of the POYO model.
The model processes input spike sequences through its encoder-processor-decoder architecture to generate task-specific predictions.
- Parameters:
  - input_unit_index (Tensor) – Indices of input units
  - input_timestamps (Tensor) – Timestamps of input spikes
  - input_token_type (Tensor) – Type of input tokens
  - input_mask (Optional[Tensor]) – Mask for input sequence
  - latent_index (Tensor) – Indices for latent tokens
  - latent_timestamps (Tensor) – Timestamps for latent tokens
  - output_session_index (Tensor) – Index of the recording session
  - output_timestamps (Tensor) – Timestamps for output predictions
  - output_mask (Optional[Tensor]) – A mask of the same size as output_timestamps. True implies that the corresponding timestamp is a valid query for POYO. This is required iff unpack_output is set to True.
  - unpack_output (bool) – If False, the function returns a padded tensor of shape (batch size, maximum number of output queries in the batch, dim_out); in this case you must use output_mask externally to select only the valid outputs. If True, it returns a list of tensors: the length of the list equals the batch size, and the i-th tensor has shape (number of valid output queries for the i-th sample, dim_out).
- Return type:
  Tensor or List[Tensor]
- Returns:
  A torch.Tensor of shape (batch, n_out, dim_out) containing the predicted outputs corresponding to output_timestamps. If unpack_output is True, a list of per-sample tensors is returned instead (see unpack_output above).
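Continuing from the instantiation sketch above, the following illustrates the call pattern and tensor shapes for forward. The shapes, the unit_emb / session_emb attribute names, and the initialize_vocab calls are assumptions made for illustration: in practice the unit and session vocabularies are set up by the data pipeline before the model is called.

```python
import torch

B, N_IN, N_LATENT, N_OUT = 2, 256, 128, 8  # batch, spikes, latent tokens, output queries

# Assumed set-up step: the unit/session vocabularies must be populated before
# forward is called; attribute and method names here are illustrative only.
model.unit_emb.initialize_vocab([f"unit_{i}" for i in range(96)])
model.session_emb.initialize_vocab(["session_0"])

preds = model(
    input_unit_index=torch.randint(0, 96, (B, N_IN)),               # (B, N_in) long
    input_timestamps=torch.rand(B, N_IN),                           # (B, N_in) float, seconds
    input_token_type=torch.zeros(B, N_IN, dtype=torch.long),        # (B, N_in) long
    input_mask=torch.ones(B, N_IN, dtype=torch.bool),               # (B, N_in) bool
    latent_index=(torch.arange(N_LATENT) % 64).repeat(B, 1),        # (B, N_latent) long
    latent_timestamps=torch.linspace(0, 1, N_LATENT).repeat(B, 1),  # (B, N_latent) float
    output_session_index=torch.zeros(B, N_OUT, dtype=torch.long),   # (B, N_out) long
    output_timestamps=torch.rand(B, N_OUT),                         # (B, N_out) float
    output_mask=torch.ones(B, N_OUT, dtype=torch.bool),             # (B, N_out) bool
)
# preds: padded tensor of shape (B, N_out, dim_out); use output_mask to keep valid queries.
```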
- class POYOPlus(*, sequence_length, readout_specs={'arm_velocity_2d': ModalitySpec(id=3, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='behavior.timestamps', value_key='behavior.hand_vel', loss_fn=MSELoss()), 'cursor_position_2d': ModalitySpec(id=2, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='cursor.timestamps', value_key='cursor.pos', loss_fn=MSELoss()), 'cursor_velocity_2d': ModalitySpec(id=1, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='cursor.timestamps', value_key='cursor.vel', loss_fn=MSELoss()), 'drifting_gratings_orientation': ModalitySpec(id=4, dim=8, type=<DataType.MULTINOMIAL: 2>, timestamp_key='drifting_gratings.timestamps', value_key='drifting_gratings.orientation_id', loss_fn=CrossEntropyLoss()), 'drifting_gratings_temporal_frequency': ModalitySpec(id=5, dim=5, type=<DataType.MULTINOMIAL: 2>, timestamp_key='drifting_gratings.timestamps', value_key='drifting_gratings.temporal_frequency_id', loss_fn=CrossEntropyLoss()), 'gabor_orientation': ModalitySpec(id=14, dim=4, type=<DataType.MULTINOMIAL: 2>, timestamp_key='gabors.timestamps', value_key='gabors.gabors_orientation', loss_fn=CrossEntropyLoss()), 'gabor_pos_2d': ModalitySpec(id=15, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='gabors.timestamps', value_key='gabors.pos_2d', loss_fn=MSELoss()), 'gaze_pos_2d': ModalitySpec(id=17, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='gaze.timestamps', value_key='gaze.pos_2d', loss_fn=MSELoss()), 'locally_sparse_noise_frame': ModalitySpec(id=9, dim=8000, type=<DataType.MULTINOMIAL: 2>, timestamp_key='locally_sparse_noise.timestamps', value_key='locally_sparse_noise.frame', loss_fn=CrossEntropyLoss()), 'natural_movie_one_frame': ModalitySpec(id=6, dim=900, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_movie_one.timestamps', value_key='natural_movie_one.frame', loss_fn=CrossEntropyLoss()), 'natural_movie_three_frame': ModalitySpec(id=8, dim=3600, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_movie_three.timestamps', value_key='natural_movie_three.frame', loss_fn=CrossEntropyLoss()), 'natural_movie_two_frame': ModalitySpec(id=7, dim=900, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_movie_two.timestamps', value_key='natural_movie_two.frame', loss_fn=CrossEntropyLoss()), 'natural_scenes': ModalitySpec(id=13, dim=119, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_scenes.timestamps', value_key='natural_scenes.frame', loss_fn=CrossEntropyLoss()), 'pupil_location': ModalitySpec(id=18, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='pupil.timestamps', value_key='pupil.location', loss_fn=MSELoss()), 'pupil_size_2d': ModalitySpec(id=19, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='pupil.timestamps', value_key='pupil.size_2d', loss_fn=MSELoss()), 'running_speed': ModalitySpec(id=16, dim=1, type=<DataType.CONTINUOUS: 0>, timestamp_key='running.timestamps', value_key='running.running_speed', loss_fn=MSELoss()), 'static_gratings_orientation': ModalitySpec(id=10, dim=6, type=<DataType.MULTINOMIAL: 2>, timestamp_key='static_gratings.timestamps', value_key='static_gratings.orientation_id', loss_fn=CrossEntropyLoss()), 'static_gratings_phase': ModalitySpec(id=12, dim=5, type=<DataType.MULTINOMIAL: 2>, timestamp_key='static_gratings.timestamps', value_key='static_gratings.phase_id', loss_fn=CrossEntropyLoss()), 'static_gratings_spatial_frequency': ModalitySpec(id=11, dim=5, type=<DataType.MULTINOMIAL: 2>, timestamp_key='static_gratings.timestamps', value_key='static_gratings.spatial_frequency_id', 
loss_fn=CrossEntropyLoss())}, latent_step, num_latents_per_step=64, dim=512, depth=2, dim_head=64, cross_heads=1, self_heads=8, ffn_dropout=0.2, lin_dropout=0.4, atn_dropout=0.0, emb_init_scale=0.02, t_min=0.0001, t_max=4.0)[source]
Bases: Module
POYO+ model from Azabou et al. 2025, Multi-session, multi-task neural decoding from distinct cell-types and brain regions.
POYO+ is a transformer-based model for neural decoding from population recordings. It extends the POYO architecture with multiple task-specific decoders.
The model processes neural spike sequences through the following steps:
- Input tokens are constructed by combining unit embeddings, token type embeddings,
and time embeddings for each spike in the sequence.
- The input sequence is compressed using cross-attention, where learnable latent
tokens (each with an associated timestamp) attend to the input tokens.
- The compressed latent token representations undergo further refinement through
multiple self-attention processing layers.
- Query tokens are constructed for the desired outputs by combining task embeddings,
session embeddings, and output timestamps.
- These query tokens attend to the processed latent representations through
cross-attention, producing outputs in the model’s dimensional space (dim).
- Finally, task-specific linear layers map the outputs from the model dimension
to the appropriate output dimension required by each task.
- Parameters:
  - sequence_length (float) – Maximum duration of the input spike sequence (in seconds)
  - readout_specs (Dict[str, ModalitySpec]) – Specifications for each prediction task. This is a dictionary with strings as keys (task names) and torch_brain.registry.ModalitySpec as values, with one key-value pair per prediction task.
  - latent_step (float) – Timestep of the latent grid (in seconds)
  - num_latents_per_step (int) – Number of unique latent tokens (repeated at every latent step)
  - dim (int) – Dimension of all embeddings
  - depth (int) – Number of processing layers
  - dim_head (int) – Dimension of each attention head
  - cross_heads (int) – Number of attention heads used in a cross-attention layer
  - self_heads (int) – Number of attention heads used in a self-attention layer
  - ffn_dropout (float) – Dropout rate for feed-forward networks
  - lin_dropout (float) – Dropout rate for linear layers
  - atn_dropout (float) – Dropout rate for attention
  - emb_init_scale (float) – Scale for embedding initialization
  - t_min (float) – Minimum timestamp resolution for rotary embeddings
  - t_max (float) – Maximum timestamp resolution for rotary embeddings
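As with POYO above, a hedged instantiation sketch: readout_specs here is a small hand-built dictionary with two hypothetical tasks whose field values are copied from the default readout_specs shown in the signature, and the import paths are assumptions.

```python
import torch.nn as nn

from torch_brain.models import POYOPlus
from torch_brain.registry import ModalitySpec, DataType  # assumed import location

# Two hypothetical readouts: a continuous 2D regression task and an 8-way
# classification task (field values copied from the default readout_specs above).
readout_specs = {
    "cursor_velocity_2d": ModalitySpec(
        id=0, dim=2, type=DataType.CONTINUOUS,
        timestamp_key="cursor.timestamps", value_key="cursor.vel",
        loss_fn=nn.MSELoss(),
    ),
    "drifting_gratings_orientation": ModalitySpec(
        id=1, dim=8, type=DataType.MULTINOMIAL,
        timestamp_key="drifting_gratings.timestamps",
        value_key="drifting_gratings.orientation_id",
        loss_fn=nn.CrossEntropyLoss(),
    ),
}

model = POYOPlus(
    sequence_length=1.0,
    readout_specs=readout_specs,
    latent_step=0.125,
    num_latents_per_step=64,
    dim=512,
)
```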
- forward(*, input_unit_index, input_timestamps, input_token_type, input_mask=None, latent_index, latent_timestamps, output_session_index, output_timestamps, output_decoder_index, unpack_output=False)[source]
Forward pass of the POYO+ model.
The model processes input spike sequences through its encoder-processor-decoder architecture to generate task-specific predictions.
- Parameters:
  - input_unit_index (Tensor) – Indices of input units
  - input_timestamps (Tensor) – Timestamps of input spikes
  - input_token_type (Tensor) – Type of input tokens
  - input_mask (Optional[Tensor]) – Mask for input sequence
  - latent_index (Tensor) – Indices for latent tokens
  - latent_timestamps (Tensor) – Timestamps for latent tokens
  - output_session_index (Tensor) – Index of the recording session
  - output_timestamps (Tensor) – Timestamps for output predictions
  - output_decoder_index (Tensor) – Indices indicating which task-specific decoder to use for each output query
  - unpack_output (bool) – If False, outputs are returned in padded (batched) form; if True, they are unpacked into one entry per sample in the batch.
- Returns:
  A list of dictionaries, each containing the predicted outputs for a given task.
- Return type:
  List[Dict[str, Tensor]]
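Continuing from the POYOPlus sketch above, the call mirrors POYO.forward, with output_decoder_index selecting which task-specific decoder handles each output query. Shapes and the embedding-vocabulary set-up remain illustrative assumptions, as in the POYO example.

```python
import torch

B, N_IN, N_LATENT, N_OUT = 2, 256, 128, 8  # batch, spikes, latent tokens, output queries

# Assumption: unit/session vocabularies are initialized as in the POYO example above.
outputs = model(
    input_unit_index=torch.randint(0, 96, (B, N_IN)),               # (B, N_in) long
    input_timestamps=torch.rand(B, N_IN),                           # (B, N_in) float, seconds
    input_token_type=torch.zeros(B, N_IN, dtype=torch.long),        # (B, N_in) long
    input_mask=torch.ones(B, N_IN, dtype=torch.bool),               # (B, N_in) bool
    latent_index=(torch.arange(N_LATENT) % 64).repeat(B, 1),        # (B, N_latent) long
    latent_timestamps=torch.linspace(0, 1, N_LATENT).repeat(B, 1),  # (B, N_latent) float
    output_session_index=torch.zeros(B, N_OUT, dtype=torch.long),   # (B, N_out) long
    output_timestamps=torch.rand(B, N_OUT),                         # (B, N_out) float
    output_decoder_index=torch.zeros(B, N_OUT, dtype=torch.long),   # (B, N_out) long: task/decoder id per query
)
```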