POYO

class torch_brain.models.POYO(*, sequence_length, latent_step, num_latents_per_step=64, dim_out, dim=512, depth=2, dim_head=64, cross_heads=1, self_heads=8, ffn_dropout=0.2, lin_dropout=0.4, atn_dropout=0.0, emb_init_scale=0.02, t_min=0.0001, t_max=2.0627)

Bases: torch.nn.modules.module.Module

POYO model from Azabou et al. 2023, A Unified, Scalable Framework for Neural Population Decoding.

POYO is a transformer-based model for neural decoding from electrophysiological recordings.

  1. An input token is constructed for each spike in the sequence by combining its unit embedding and token type embedding with a (rotary) time embedding of its timestamp.

  2. The input sequence is compressed using cross-attention, where learnable latent tokens (each with an associated timestamp) attend to the input tokens.

  3. The compressed latent tokens are then refined by a stack of self-attention processing layers; the number of layers is set by depth.

  4. Query tokens are constructed for the desired outputs by combining session embeddings with output timestamps.

  5. These query tokens attend to the processed latent representations through cross-attention, producing outputs in the model’s hidden dimension (dim).

  6. Finally, a task-specific linear layer maps each output from the model dimension (dim) to the output dimension (dim_out).
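Steps 2 and 5 both rely on cross-attention, in which a short sequence of queries attends to a longer sequence and the output length follows the queries, not the inputs. The following is a minimal single-head sketch in pure Python, assuming plain scaled dot-product attention with no learned projections, masking, or rotary time embeddings; the function names and sizes are hypothetical and not part of torch_brain.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    # Subtract the max for numerical stability before exponentiating.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries: List[Vector], keys: List[Vector],
                    values: List[Vector]) -> List[Vector]:
    """Hypothetical single-head scaled dot-product cross-attention."""
    d = len(keys[0])
    outputs = []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        # Each output is a weighted average of the value vectors.
        outputs.append([
            sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))
        ])
    return outputs


# Hypothetical sizes: 6 input (spike) tokens compressed into 2 latent tokens, dim = 4.
inputs = [[float((i + j) % 3) for j in range(4)] for i in range(6)]
latents = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]

compressed = cross_attention(latents, inputs, inputs)
print(len(compressed), len(compressed[0]))  # 2 4
```

The same function covers step 5 if the output query tokens are passed as queries and the processed latents as keys and values.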

Parameters:
  • sequence_length (float) – Maximum duration of the input spike sequence (in seconds)

  • latent_step (float) – Timestep of the latent grid (in seconds)

  • num_latents_per_step (int) – Number of unique latent tokens (repeated at every latent step)

  • dim (int) – Hidden dimension of the model

  • dim_out (int) – Output dimension of the model

  • depth (int) – Number of processing layers (self-attentions in the latent space)

  • dim_head (int) – Dimension of each attention head

  • cross_heads (int) – Number of attention heads used in a cross-attention layer

  • self_heads (int) – Number of attention heads used in a self-attention layer

  • ffn_dropout (float) – Dropout rate for feed-forward networks

  • lin_dropout (float) – Dropout rate for linear layers

  • atn_dropout (float) – Dropout rate for attention

  • emb_init_scale (float) – Scale for embedding initialization

  • t_min (float) – Minimum timestamp resolution for rotary embeddings

  • t_max (float) – Maximum timestamp resolution for rotary embeddings
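t_min and t_max bound the time scales that the rotary embeddings can represent. The pure-Python sketch below illustrates rotary time embeddings under one assumption that is labeled in the code: rotation periods are spaced geometrically from t_min to t_max. The exact spacing used inside torch_brain may differ, and the helper names are hypothetical. The sketch checks the property that makes rotary embeddings useful for spike timing: the query-key score depends only on the time difference between the two tokens.

```python
import math
from typing import List


def rotary_frequencies(dim: int, t_min: float, t_max: float) -> List[float]:
    """Angular frequencies for dim // 2 rotation planes (requires dim >= 4).

    Assumption: periods are spaced geometrically from t_min up to t_max.
    """
    n = dim // 2
    periods = [t_min * (t_max / t_min) ** (i / (n - 1)) for i in range(n)]
    return [2 * math.pi / p for p in periods]


def rotate(x: List[float], t: float, freqs: List[float]) -> List[float]:
    """Rotate each pair (x[2i], x[2i+1]) by the angle t * freqs[i]."""
    out: List[float] = []
    for i, w in enumerate(freqs):
        a, b = x[2 * i], x[2 * i + 1]
        c, s = math.cos(w * t), math.sin(w * t)
        out += [a * c - b * s, a * s + b * c]
    return out


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


# Defaults from the POYO signature: t_min=0.0001, t_max=2.0627.
freqs = rotary_frequencies(dim=8, t_min=1e-4, t_max=2.0627)
q = [0.3, -1.2, 0.5, 0.8, -0.7, 0.1, 1.1, 0.4]
k = [0.9, 0.2, -0.4, 0.6, 0.3, -1.0, 0.2, 0.7]

# Same time difference (0.25 s) at two different absolute times.
s1 = dot(rotate(q, 1.50, freqs), rotate(k, 1.25, freqs))
s2 = dot(rotate(q, 0.75, freqs), rotate(k, 0.50, freqs))
print(abs(s1 - s2) < 1e-9)  # True
```

Because only relative timing enters the attention score, shifting a whole window in time leaves the attention pattern unchanged in this sketch.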

forward(*, input_unit_index, input_timestamps, input_token_type, input_mask=None, latent_index, latent_timestamps, session_index, output_timestamps, output_mask=None, unpack_output=False)

Forward pass of the POYO model.

The model processes input spike sequences through its encoder-processor-decoder architecture to generate task-specific predictions.

Parameters:
  • input_unit_index (Tensor) – Unit indices of shape \((B, N_{in})\).

  • input_timestamps (Tensor) – Spike timestamps of shape \((B, N_{in})\).

  • input_token_type (Tensor) – Token type indices of shape \((B, N_{in})\).

  • input_mask (Optional[Tensor]) – Boolean mask of shape \((B, N_{in})\).

  • latent_index (Tensor) – Latent token indices of shape \((B, N_{lat})\).

  • latent_timestamps (Tensor) – Latent token timestamps of shape \((B, N_{lat})\).

  • session_index (Tensor) – Session indices of shape \((B,)\).

  • output_timestamps (Tensor) – Output query timestamps of shape \((B, N_{out})\).

  • output_mask (Optional[Tensor]) – Boolean mask of shape \((B, N_{out})\). Required when unpack_output=True.

  • unpack_output (bool) – If False, returns a padded tensor of shape \((B, N_{out}, D_{out})\); use output_mask to index valid entries. If True, returns a list of \(B\) tensors each of shape \((N_{out,i}, D_{out})\) containing only the valid outputs.
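The latent_index and latent_timestamps inputs describe the latent grid implied by sequence_length, latent_step, and num_latents_per_step. The sketch below builds one such grid in pure Python. make_latent_grid is a hypothetical helper, and placing each group of latents at the center of its time step is an assumption, not necessarily what POYO.tokenize does.

```python
from typing import List, Tuple


def make_latent_grid(sequence_length: float, latent_step: float,
                     num_latents_per_step: int) -> Tuple[List[int], List[float]]:
    """Hypothetical helper: repeat num_latents_per_step latent IDs at every step.

    Assumes sequence_length is a multiple of latent_step and that each group
    of latents sits at the center of its step.
    """
    num_steps = round(sequence_length / latent_step)
    latent_index: List[int] = []
    latent_timestamps: List[float] = []
    for step in range(num_steps):
        t = (step + 0.5) * latent_step
        for i in range(num_latents_per_step):
            latent_index.append(i)       # the same learnable latent IDs at each step
            latent_timestamps.append(t)  # each copy gets its step's timestamp
    return latent_index, latent_timestamps


idx, ts = make_latent_grid(sequence_length=1.0, latent_step=0.125, num_latents_per_step=4)
print(len(idx))  # 32
print(idx[:6])   # [0, 1, 2, 3, 0, 1]
print(ts[:6])    # [0.0625, 0.0625, 0.0625, 0.0625, 0.1875, 0.1875]
```

Under this layout, N_lat equals (sequence_length / latent_step) × num_latents_per_step, and batching B samples stacks these lists into the (B, N_lat) tensors described above.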

Return type:

Union[Tensor, List[Tensor]]

Returns:

Shape \((B, N_{out}, D_{out})\) when unpack_output=False, or a list of \(B\) tensors of shape \((N_{out,i}, D_{out})\) when unpack_output=True.
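The two return forms are related by a simple selection: unpacking keeps, for each batch element, only the rows where output_mask is True. The pure-Python sketch below uses nested lists in place of tensors and a hypothetical unpack helper; with PyTorch tensors the same selection can be written as [out[i][mask[i]] for i in range(B)].

```python
from typing import List

Row = List[float]


def unpack(padded: List[List[Row]], mask: List[List[bool]]) -> List[List[Row]]:
    """Hypothetical helper: keep only the valid output rows of each batch element."""
    return [
        [row for row, valid in zip(rows, m) if valid]
        for rows, m in zip(padded, mask)
    ]


# B = 2, N_out = 3 (padded), D_out = 2
padded = [
    [[0.1, 0.2], [0.3, 0.4], [0.0, 0.0]],  # third row is padding
    [[0.5, 0.6], [0.0, 0.0], [0.0, 0.0]],  # only the first row is valid
]
mask = [[True, True, False], [True, False, False]]

unpacked = unpack(padded, mask)
print([len(x) for x in unpacked])  # [2, 1]
```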

tokenize(data)

Tokenizes a Data sample for the POYO model.

This tokenizer can be called as a transform. If you are applying multiple transforms, make sure to apply this one last.

This code runs on CPU. Do not access GPU tensors inside this function.
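The "apply this one last" rule follows from what the tokenizer produces: once a sample is tokenized, the raw fields that other transforms operate on are no longer there. The sketch below shows the ordering with a hypothetical compose helper and two stand-ins, crop_to_window and fake_tokenize, which are illustrative placeholders and not torch_brain APIs; in a real pipeline model.tokenize would occupy the last slot.

```python
from typing import Any, Callable, Dict, List

Transform = Callable[[Any], Any]


def compose(transforms: List[Transform]) -> Transform:
    """Hypothetical helper: apply transforms from left to right."""
    def apply(sample: Any) -> Any:
        for t in transforms:
            sample = t(sample)
        return sample
    return apply


def crop_to_window(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in augmentation on raw data: keep spikes in [0, 1) seconds.
    spikes = [t for t in sample["spike_times"] if 0.0 <= t < 1.0]
    return {**sample, "spike_times": spikes}


def fake_tokenize(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in for model.tokenize: emits model inputs instead of raw fields.
    return {"input_timestamps": list(sample["spike_times"])}


pipeline = compose([crop_to_window, fake_tokenize])  # tokenizer last
print(pipeline({"spike_times": [-0.2, 0.1, 0.7, 1.3]}))
# {'input_timestamps': [0.1, 0.7]}
```

Reversing the order in this sketch raises a KeyError, because the tokenized dictionary no longer contains the spike_times field that the augmentation expects.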

Return type:

Dict

init_vocabs(dataset)

Initializes the model’s unit_emb and session_emb vocabularies.

Parameters:

dataset (Dataset) – A Dataset with SpikingDatasetMixin.
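A vocabulary is what lets an embedding layer turn dataset identifiers (unit or session IDs) into the integer indices used by inputs such as input_unit_index. As a rough illustration only, the sketch below builds an ID-to-index mapping in pure Python; build_vocab and the example IDs are hypothetical, and the real unit_emb and session_emb layers may order IDs differently or reserve extra entries.

```python
from typing import Dict, Iterable


def build_vocab(ids: Iterable[str]) -> Dict[str, int]:
    """Hypothetical: assign each unique ID a stable integer index."""
    return {name: i for i, name in enumerate(sorted(set(ids)))}


unit_ids = ["sess_b/unit_2", "sess_a/unit_0", "sess_a/unit_1", "sess_b/unit_2"]
vocab = build_vocab(unit_ids)
print(vocab)
# {'sess_a/unit_0': 0, 'sess_a/unit_1': 1, 'sess_b/unit_2': 2}

# String IDs become integer indices, analogous to input_unit_index.
print([vocab[u] for u in ["sess_a/unit_1", "sess_b/unit_2"]])  # [1, 2]
```

In this sketch an ID missing from the vocabulary raises a KeyError, which illustrates why the vocabularies are initialized from the dataset whose units and sessions the model will see.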

classmethod load_pretrained(checkpoint_path)

Load a pretrained POYO model from a checkpoint file.

Parameters:

checkpoint_path (str | Path) – Path to the checkpoint file containing model weights and hyperparameters.

Return type:

POYO

Returns:

An instance of the POYO model with weights loaded from the checkpoint.