torch_brain.models

Brain decoders

class POYO(*, sequence_length, readout_spec, latent_step, num_latents_per_step=64, dim=512, depth=2, dim_head=64, cross_heads=1, self_heads=8, ffn_dropout=0.2, lin_dropout=0.4, atn_dropout=0.0, emb_init_scale=0.02, t_min=0.0001, t_max=4.0)[source]

Bases: Module

POYO model from Azabou et al. 2023, A Unified, Scalable Framework for Neural Population Decoding.

POYO is a transformer-based model for neural decoding from electrophysiological recordings. The model processes neural spike sequences through the following steps:

  1. Input tokens are constructed by combining unit embeddings, token type embeddings, and time embeddings for each spike in the sequence.

  2. The input sequence is compressed using cross-attention, where learnable latent tokens (each with an associated timestamp) attend to the input tokens (see the latent-grid sketch after this list).

  3. The compressed latent token representations undergo further refinement through multiple self-attention processing layers.

  4. Query tokens are constructed for the desired outputs by combining session embeddings and output timestamps.

  5. These query tokens attend to the processed latent representations through cross-attention, producing outputs in the model’s dimensional space (dim).

  6. Finally, a task-specific linear layer maps the outputs from the model dimension to the appropriate output dimension.
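
For intuition, here is a rough sketch of how the latent grid in step 2 can be laid out. This is an illustration only, not the library's internal code, and the variable names are hypothetical.

    import torch

    # num_latents_per_step learnable tokens are tiled at every latent_step
    # interval of the input window (see the constructor parameters below).
    sequence_length = 1.0       # seconds
    latent_step = 0.125         # seconds
    num_latents_per_step = 64

    n_steps = int(sequence_length / latent_step)                 # 8 steps
    latent_index = torch.arange(num_latents_per_step).repeat(n_steps)
    latent_timestamps = (
        torch.arange(n_steps).repeat_interleave(num_latents_per_step)
        * latent_step
    )
    # latent_index and latent_timestamps both have shape (512,):
    # the same 64 latent tokens appear at each of the 8 grid steps.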

Parameters:
  • sequence_length (float) – Maximum duration of the input spike sequence (in seconds)

  • readout_spec (ModalitySpec) – A torch_brain.registry.ModalitySpec specifying readout properties

  • latent_step (float) – Timestep of the latent grid (in seconds)

  • num_latents_per_step (int) – Number of unique latent tokens (repeated at every latent step)

  • dim (int) – Hidden dimension of the model

  • depth (int) – Number of processing layers (self-attentions in the latent space)

  • dim_head (int) – Dimension of each attention head

  • cross_heads (int) – Number of attention heads used in a cross-attention layer

  • self_heads (int) – Number of attention heads used in a self-attention layer

  • ffn_dropout (float) – Dropout rate for feed-forward networks

  • lin_dropout (float) – Dropout rate for linear layers

  • atn_dropout (float) – Dropout rate for attention

  • emb_init_scale (float) – Scale for embedding initialization

  • t_min (float) – Minimum timestamp resolution for rotary embeddings

  • t_max (float) – Maximum timestamp resolution for rotary embeddings
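
A minimal construction sketch. The registry lookup assumes torch_brain.registry.MODALITY_REGISTRY contains a "cursor_velocity_2d" entry, mirroring the defaults shown for POYOPlus below; treat the exact import path as an assumption.

    import torch_brain
    from torch_brain.models import POYO

    # Hypothetical setup: decode 2D cursor velocity from spiking activity.
    readout_spec = torch_brain.registry.MODALITY_REGISTRY["cursor_velocity_2d"]

    model = POYO(
        sequence_length=1.0,      # 1 s windows of spikes
        readout_spec=readout_spec,
        latent_step=0.125,        # latent grid every 125 ms
        num_latents_per_step=64,
        dim=512,
    )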

forward(*, input_unit_index, input_timestamps, input_token_type, input_mask=None, latent_index, latent_timestamps, output_session_index, output_timestamps, output_mask=None, unpack_output=False)[source]

Forward pass of the POYO model.

The model processes input spike sequences through its encoder-processor-decoder architecture to generate task-specific predictions.

Parameters:
  • input_unit_index (Tensor) – Indices of input units

  • input_timestamps (Tensor) – Timestamps of input spikes

  • input_token_type (Tensor) – Type of input tokens

  • input_mask (Optional[Tensor]) – Mask for input sequence

  • latent_index (Tensor) – Indices for latent tokens

  • latent_timestamps (Tensor) – Timestamps for latent tokens

  • output_session_index (Tensor) – Index of the recording session

  • output_timestamps (Tensor) – Timestamps for output predictions

  • output_mask (Optional[Tensor]) – Boolean mask of the same shape as output_timestamps; True marks that timestamp as a valid query for POYO. Required if and only if unpack_output is True.

  • unpack_output (bool) – If False, returns a padded tensor of shape (batch_size, max_num_output_queries_in_batch, dim_out); in this case you must apply output_mask externally to select valid outputs. If True, returns a list of batch_size tensors, where the i-th tensor has shape (num_valid_output_queries_for_sample_i, dim_out).

Return type:

Union[Tensor, List[Tensor]]

Returns:

If unpack_output is False, a torch.Tensor of shape (batch, n_out, dim_out) containing the predicted outputs corresponding to output_timestamps; if True, a list of per-sample tensors containing only the valid predictions.
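
A hedged sketch of a batched forward call with dummy inputs, continuing from the construction example above. The (batch, sequence) tensor shapes are an assumption based on the padded-batch semantics of unpack_output; in practice these tensors come from tokenize() plus a collator, which also set up any embedding vocabularies the model needs.

    import torch

    B, N_in, N_out = 2, 100, 10   # batch, input spikes, output queries

    latent_index = torch.arange(64).repeat(8)
    latent_timestamps = torch.arange(8).repeat_interleave(64) * 0.125

    # `model` is the POYO instance constructed above.
    output = model(
        input_unit_index=torch.randint(0, 50, (B, N_in)),
        input_timestamps=torch.rand(B, N_in),      # in [0, sequence_length)
        input_token_type=torch.zeros(B, N_in, dtype=torch.long),
        input_mask=torch.ones(B, N_in, dtype=torch.bool),
        latent_index=latent_index.expand(B, -1),
        latent_timestamps=latent_timestamps.expand(B, -1),
        output_session_index=torch.zeros(B, N_out, dtype=torch.long),
        output_timestamps=torch.rand(B, N_out),
        output_mask=torch.ones(B, N_out, dtype=torch.bool),
        unpack_output=True,
    )
    # With unpack_output=True: a list of B tensors,
    # each of shape (num_valid_queries_i, dim_out).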

tokenize(data)[source]

Tokenizer used to tokenize Data for the POYO model.

This tokenizer can be called as a transform. If you are applying multiple transforms, make sure to apply this one last.

This code runs on CPU. Do not access GPU tensors inside this function.

Return type:

Dict
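
Since the tokenizer can be applied as a transform, a typical pattern is to compose it last in the pipeline. The surrounding names (my_augmentation, the dataset wiring) are placeholders for illustration.

    # Hypothetical pipeline sketch: augment first, tokenize last, so the
    # model sees tokens built from the final transformed data.
    def transform(data):
        data = my_augmentation(data)   # any Data -> Data transform (placeholder)
        return model.tokenize(data)    # must run last; returns a dict of tensors

    # Pass `transform` to your dataset so tokenization runs on CPU inside
    # the dataloader workers, never on GPU tensors.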

class POYOPlus(*, sequence_length, readout_specs=<default modality registry, listed below>, latent_step, num_latents_per_step=64, dim=512, depth=2, dim_head=64, cross_heads=1, self_heads=8, ffn_dropout=0.2, lin_dropout=0.4, atn_dropout=0.0, emb_init_scale=0.02, t_min=0.0001, t_max=4.0)[source]

By default, readout_specs is the full modality registry, one torch_brain.registry.ModalitySpec per task:

  • arm_velocity_2d – id=3, dim=2, CONTINUOUS, keys: behavior.timestamps / behavior.hand_vel, loss: MSELoss
  • cursor_position_2d – id=2, dim=2, CONTINUOUS, keys: cursor.timestamps / cursor.pos, loss: MSELoss
  • cursor_velocity_2d – id=1, dim=2, CONTINUOUS, keys: cursor.timestamps / cursor.vel, loss: MSELoss
  • drifting_gratings_orientation – id=4, dim=8, MULTINOMIAL, keys: drifting_gratings.timestamps / drifting_gratings.orientation_id, loss: CrossEntropyLoss
  • drifting_gratings_temporal_frequency – id=5, dim=5, MULTINOMIAL, keys: drifting_gratings.timestamps / drifting_gratings.temporal_frequency_id, loss: CrossEntropyLoss
  • gabor_orientation – id=14, dim=4, MULTINOMIAL, keys: gabors.timestamps / gabors.gabors_orientation, loss: CrossEntropyLoss
  • gabor_pos_2d – id=15, dim=2, CONTINUOUS, keys: gabors.timestamps / gabors.pos_2d, loss: MSELoss
  • gaze_pos_2d – id=17, dim=2, CONTINUOUS, keys: gaze.timestamps / gaze.pos_2d, loss: MSELoss
  • locally_sparse_noise_frame – id=9, dim=8000, MULTINOMIAL, keys: locally_sparse_noise.timestamps / locally_sparse_noise.frame, loss: CrossEntropyLoss
  • natural_movie_one_frame – id=6, dim=900, MULTINOMIAL, keys: natural_movie_one.timestamps / natural_movie_one.frame, loss: CrossEntropyLoss
  • natural_movie_two_frame – id=7, dim=900, MULTINOMIAL, keys: natural_movie_two.timestamps / natural_movie_two.frame, loss: CrossEntropyLoss
  • natural_movie_three_frame – id=8, dim=3600, MULTINOMIAL, keys: natural_movie_three.timestamps / natural_movie_three.frame, loss: CrossEntropyLoss
  • natural_scenes – id=13, dim=119, MULTINOMIAL, keys: natural_scenes.timestamps / natural_scenes.frame, loss: CrossEntropyLoss
  • pupil_location – id=18, dim=2, CONTINUOUS, keys: pupil.timestamps / pupil.location, loss: MSELoss
  • pupil_size_2d – id=19, dim=2, CONTINUOUS, keys: pupil.timestamps / pupil.size_2d, loss: MSELoss
  • running_speed – id=16, dim=1, CONTINUOUS, keys: running.timestamps / running.running_speed, loss: MSELoss
  • static_gratings_orientation – id=10, dim=6, MULTINOMIAL, keys: static_gratings.timestamps / static_gratings.orientation_id, loss: CrossEntropyLoss
  • static_gratings_phase – id=12, dim=5, MULTINOMIAL, keys: static_gratings.timestamps / static_gratings.phase_id, loss: CrossEntropyLoss
  • static_gratings_spatial_frequency – id=11, dim=5, MULTINOMIAL, keys: static_gratings.timestamps / static_gratings.spatial_frequency_id, loss: CrossEntropyLoss

Bases: Module

POYO+ model from Azabou et al. 2025, Multi-session, multi-task neural decoding from distinct cell-types and brain regions.

POYO+ is a transformer-based model for neural decoding from population recordings. It extends the POYO architecture with multiple task-specific decoders.

The model processes neural spike sequences through the following steps:

  1. Input tokens are constructed by combining unit embeddings, token type embeddings, and time embeddings for each spike in the sequence.

  2. The input sequence is compressed using cross-attention, where learnable latent tokens (each with an associated timestamp) attend to the input tokens.

  3. The compressed latent token representations undergo further refinement through multiple self-attention processing layers.

  4. Query tokens are constructed for the desired outputs by combining task embeddings, session embeddings, and output timestamps.

  5. These query tokens attend to the processed latent representations through cross-attention, producing outputs in the model’s dimensional space (dim).

  6. Finally, task-specific linear layers map the outputs from the model dimension to the appropriate output dimension required by each task.

Parameters:
  • sequence_length (float) – Maximum duration of the input spike sequence (in seconds)

  • readout_specs (Dict[str, ModalitySpec]) – Specifications for each prediction task. This is a dictionary with strings as keys (task names), and torch_brain.registry.ModalitySpec as values. One key-value pair for each prediction task.

  • latent_step (float) – Timestep of the latent grid (in seconds)

  • num_latents_per_step (int) – Number of unique latent tokens (repeated at every latent step)

  • dim (int) – Dimension of all embeddings

  • depth (int) – Number of processing layers

  • dim_head (int) – Dimension of each attention head

  • cross_heads (int) – Number of attention heads used in a cross-attention layer

  • self_heads (int) – Number of attention heads used in a self-attention layer

  • ffn_dropout (float) – Dropout rate for feed-forward networks

  • lin_dropout (float) – Dropout rate for linear layers

  • atn_dropout (float) – Dropout rate for attention

  • emb_init_scale (float) – Scale for embedding initialization

  • t_min (float) – Minimum timestamp resolution for rotary embeddings

  • t_max (float) – Maximum timestamp resolution for rotary embeddings
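
A construction sketch with an explicit two-task readout_specs. The field values mirror entries from the default registry listed above, but treat the exact ModalitySpec constructor signature and the torch_brain.registry import path as assumptions.

    import torch.nn as nn
    from torch_brain.models import POYOPlus
    from torch_brain.registry import DataType, ModalitySpec

    # Hypothetical two-task setup mirroring the default registry entries.
    readout_specs = {
        "cursor_velocity_2d": ModalitySpec(
            id=1, dim=2, type=DataType.CONTINUOUS,
            timestamp_key="cursor.timestamps", value_key="cursor.vel",
            loss_fn=nn.MSELoss(),
        ),
        "running_speed": ModalitySpec(
            id=16, dim=1, type=DataType.CONTINUOUS,
            timestamp_key="running.timestamps", value_key="running.running_speed",
            loss_fn=nn.MSELoss(),
        ),
    }

    model = POYOPlus(
        sequence_length=1.0,
        readout_specs=readout_specs,
        latent_step=0.125,
    )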

forward(*, input_unit_index, input_timestamps, input_token_type, input_mask=None, latent_index, latent_timestamps, output_session_index, output_timestamps, output_decoder_index, unpack_output=False)[source]

Forward pass of the POYO+ model.

The model processes input spike sequences through its encoder-processor-decoder architecture to generate task-specific predictions.

Parameters:
  • input_unit_index (Tensor) – Indices of input units

  • input_timestamps (Tensor) – Timestamps of input spikes

  • input_token_type (Tensor) – Type of input tokens

  • input_mask (Optional[Tensor]) – Mask for input sequence

  • latent_index (Tensor) – Indices for latent tokens

  • latent_timestamps (Tensor) – Timestamps for latent tokens

  • output_session_index (Tensor) – Index of the recording session

  • output_timestamps (Tensor) – Timestamps for output predictions

  • output_decoder_index (Tensor) – Indices indicating which task-specific decoder to use for each query

  • unpack_output (bool) – If False, returns padded predictions that must be filtered externally; if True, returns per-sample outputs (see POYO.forward() for the packing convention)

Returns:

If unpack_output is False, a padded tensor of predictions; if True, a list with one entry per sample, each a dictionary mapping a task name to the predicted outputs for that task, produced by the decoder selected by output_decoder_index.

Return type:

Union[Tensor, List[Dict[str, Tensor]]]
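
A hedged sketch of a POYO+ forward call. Shapes and the routing convention (one readout id per query token, matching the ModalitySpec ids above) are assumptions; real inputs come from tokenize() plus a collator.

    import torch

    B, N_in, N_out = 2, 100, 10
    latent_index = torch.arange(64).repeat(8)
    latent_timestamps = torch.arange(8).repeat_interleave(64) * 0.125

    # `model` is the POYOPlus instance constructed above.
    outputs = model(
        input_unit_index=torch.randint(0, 50, (B, N_in)),
        input_timestamps=torch.rand(B, N_in),
        input_token_type=torch.zeros(B, N_in, dtype=torch.long),
        latent_index=latent_index.expand(B, -1),
        latent_timestamps=latent_timestamps.expand(B, -1),
        output_session_index=torch.zeros(B, N_out, dtype=torch.long),
        output_timestamps=torch.rand(B, N_out),
        # route every query to the "cursor_velocity_2d" readout (id=1 above)
        output_decoder_index=torch.full((B, N_out), 1, dtype=torch.long),
    )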

tokenize(data)[source]

Tokenizer used to tokenize Data for the POYO+ model.

This tokenizer can be called as a transform. If you are applying multiple transforms, make sure to apply this one last.

This code runs on CPU. Do not access GPU tensors inside this function.

Return type:

Dict
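
As with POYO, the tokenizer is the last transform in the pipeline; its output dict is what (after batching) feeds POYOPlus.forward(). The collate helper below is a placeholder.

    # Hypothetical sketch: tokenize on CPU inside the dataset, batch later.
    sample = model.tokenize(data)            # dict of per-sample model inputs
    # batch = collate([sample_a, sample_b])  # placeholder collator
    # outputs = model(**batch)               # feeds forward() shown above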