POYO
- class torch_brain.models.POYO(*, sequence_length, readout_spec, latent_step, num_latents_per_step=64, dim=512, depth=2, dim_head=64, cross_heads=1, self_heads=8, ffn_dropout=0.2, lin_dropout=0.4, atn_dropout=0.0, emb_init_scale=0.02, t_min=0.0001, t_max=2.0627)
Bases: torch.nn.modules.module.Module
POYO model from Azabou et al. 2023, A Unified, Scalable Framework for Neural Population Decoding.
POYO is a transformer-based model for neural decoding from electrophysiological recordings.
- Input tokens are constructed by combining unit embeddings, token type embeddings,
and time embeddings for each spike in the sequence.
- The input sequence is compressed using cross-attention, where learnable latent
tokens (each with an associated timestamp) attend to the input tokens.
- The compressed latent token representations undergo further refinement through
multiple self-attention processing layers.
- Query tokens are constructed for the desired outputs by combining session
embeddings and output timestamps.
- These query tokens attend to the processed latent representations through
cross-attention, producing outputs of size dim (the model's hidden dimension).
- Finally, a task-specific linear layer maps the outputs from the model dimension
to the appropriate output dimension.
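The compress, process, and query steps above can be traced with a toy example. The following is only a minimal sketch in plain Python, not the library's implementation: it uses a single attention head and omits the learned projections, rotary time embeddings, residual connections, feed-forward layers, and the final linear readout. The function name `cross_attention` and all toy values are assumptions made for illustration.

```python
import math


def cross_attention(queries, keys, values):
    """Single-head scaled dot-product attention on plain Python lists."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    dim_v = len(values[0])
    outputs = []
    for q in queries:
        # Similarity of this query to every key.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        # Numerically stable softmax over the keys.
        peak = max(scores)
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted average of the values.
        outputs.append(
            [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim_v)]
        )
    return outputs


# Toy sizes (assumed): 6 input tokens, 2 latent tokens, 3 output queries, dim 4.
inputs = [[math.sin(i + j) for j in range(4)] for i in range(6)]
latents = [[math.cos(i * j) for j in range(4)] for i in range(2)]
queries = [[0.1 * (i + j) for j in range(4)] for i in range(3)]

compressed = cross_attention(latents, inputs, inputs)            # latents attend to inputs
processed = cross_attention(compressed, compressed, compressed)  # one self-attention step
decoded = cross_attention(queries, processed, processed)         # queries attend to latents

shapes = [(len(x), len(x[0])) for x in (compressed, processed, decoded)]
```

The point of the sketch is the shapes: the latent sequence length is fixed by the number of latent tokens regardless of how many input tokens there are, and the output length is set by the number of query tokens.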
- Parameters:
sequence_length (float) – Maximum duration of the input spike sequence (in seconds)
readout_spec (ModalitySpec) – A torch_brain.registry.ModalitySpec specifying readout properties
latent_step (float) – Timestep of the latent grid (in seconds)
num_latents_per_step (int) – Number of unique latent tokens (repeated at every latent step)
dim (int) – Hidden dimension of the model
depth (int) – Number of processing layers (self-attentions in the latent space)
dim_head (int) – Dimension of each attention head
cross_heads (int) – Number of attention heads used in a cross-attention layer
self_heads (int) – Number of attention heads used in a self-attention layer
ffn_dropout (float) – Dropout rate for feed-forward networks
lin_dropout (float) – Dropout rate for linear layers
atn_dropout (float) – Dropout rate for attention
emb_init_scale (float) – Scale for embedding initialization
t_min (float) – Minimum timestamp resolution for rotary embeddings
t_max (float) – Maximum timestamp resolution for rotary embeddings
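To see how sequence_length, latent_step, and num_latents_per_step interact, the sketch below builds a latent grid in plain Python. It is an illustration under stated assumptions rather than the library's code: the helper name `make_latent_grid` is hypothetical, and placing each group of latents at the centre of its step is an assumption.

```python
def make_latent_grid(sequence_length, latent_step, num_latents_per_step):
    """Hypothetical helper: returns (latent_index, latent_timestamps) as lists."""
    num_steps = round(sequence_length / latent_step)
    latent_index, latent_timestamps = [], []
    for step in range(num_steps):
        # Assumption: latents sit at the centre of each latent step.
        center = (step + 0.5) * latent_step
        for token in range(num_latents_per_step):
            # The same set of latent token ids is repeated at every step.
            latent_index.append(token)
            latent_timestamps.append(center)
    return latent_index, latent_timestamps


index, timestamps = make_latent_grid(
    sequence_length=1.0, latent_step=0.125, num_latents_per_step=4
)
```

Under these assumptions a 1-second window with a 0.125-second step and 4 latents per step yields 8 × 4 = 32 latent tokens, so the latent sequence length scales with sequence_length / latent_step times num_latents_per_step.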
- forward(*, input_unit_index, input_timestamps, input_token_type, input_mask=None, latent_index, latent_timestamps, output_session_index, output_timestamps, output_mask=None, unpack_output=False)
Forward pass of the POYO model.
The model processes input spike sequences through its encoder-processor-decoder architecture to generate task-specific predictions.
- Parameters:
input_unit_index (Tensor) – Indices of input units
input_timestamps (Tensor) – Timestamps of input spikes
input_token_type (Tensor) – Type of input tokens
latent_index (Tensor) – Indices for latent tokens
latent_timestamps (Tensor) – Timestamps for latent tokens
output_session_index (Tensor) – Index of the recording session
output_timestamps (Tensor) – Timestamps for output predictions
output_mask (Optional[Tensor]) – A mask of the same size as output_timestamps. True means that the corresponding timestamp is a valid query for POYO. This is required if and only if unpack_output is set to True.
unpack_output (bool) – If False, this function returns a padded tensor of shape (batch size, max number of output queries in the batch, dim_out). In this case you have to apply output_mask externally to keep only the valid outputs. If True, it returns a list of Tensors: the length of the list equals the batch size, and the i-th Tensor has shape (number of valid output queries for the i-th sample, dim_out).
- Return type:
- Returns:
A torch.Tensor of shape (batch, n_out, dim_out) containing the predicted outputs corresponding to output_timestamps.
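The relationship between the padded output, output_mask, and the unpacked list can be illustrated with nested Python lists standing in for tensors. This is a sketch of the semantics described above, not the library's implementation; the function name `unpack_outputs` and the toy values are assumptions.

```python
def unpack_outputs(padded, mask):
    """Keep only the rows whose mask entry is True, separately for each sample."""
    return [
        [row for row, keep in zip(sample, sample_mask) if keep]
        for sample, sample_mask in zip(padded, mask)
    ]


# Toy batch (assumed): 2 samples, up to 3 output queries each, dim_out = 2.
padded = [
    [[0.1, 0.2], [0.3, 0.4], [0.0, 0.0]],  # 2 valid queries, 1 padding row
    [[0.5, 0.6], [0.0, 0.0], [0.0, 0.0]],  # 1 valid query, 2 padding rows
]
mask = [
    [True, True, False],
    [True, False, False],
]

unpacked = unpack_outputs(padded, mask)
valid_counts = [len(sample) for sample in unpacked]
```

With unpack_output=False the caller would perform this filtering with output_mask; with unpack_output=True the per-sample lists are returned directly.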
- tokenize(data)
Tokenizer used to tokenize Data for the POYO model.
This tokenizer can be called as a transform. If you are applying multiple transforms, make sure to apply this one last.
This code runs on CPU. Do not access GPU tensors inside this function.
- Return type:
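The advice to apply the tokenizer last can be sketched with a tiny transform pipeline. Everything here is a stand-in: `compose`, `crop_window`, and `toy_tokenize` are hypothetical names, and the dictionary is a placeholder for the real data object. The sketch only shows the ordering, in which earlier transforms see the raw data and the tokenizer sees their result.

```python
def compose(*transforms):
    """Chain transforms so that each one receives the previous one's output."""
    def apply(data):
        for transform in transforms:
            data = transform(data)
        return data
    return apply


def crop_window(data):
    # Hypothetical stand-in transform: keep spikes in the first second only.
    return {**data, "spikes": [t for t in data["spikes"] if t < 1.0]}


def toy_tokenize(data):
    # Hypothetical stand-in for model.tokenize: converts data into model inputs.
    return {"input_timestamps": list(data["spikes"])}


# The tokenizer goes last, so it operates on the already-transformed data.
pipeline = compose(crop_window, toy_tokenize)
tokens = pipeline({"spikes": [0.2, 0.7, 1.4]})
```

If the tokenizer ran first, later transforms would receive model inputs rather than the data they expect, which is why the ordering matters.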
- classmethod load_pretrained(checkpoint_path, readout_spec, skip_readout=False)
Load a pretrained POYO model from a checkpoint file.
- Parameters:
checkpoint_path (str or Path) – Path to the checkpoint file containing model weights and hyperparameters.
readout_spec (ModalitySpec) – Specification for the readout modality, used to initialize the model.
skip_readout (bool, optional) – If True, the readout layer weights from the checkpoint are ignored and a new readout layer is initialized. Default is False.
- Returns:
An instance of the POYO model with weights loaded from the checkpoint.
- Return type:
- Usage:
model = POYO.load_pretrained("path/to/checkpoint.ckpt", readout_spec)
Notes
The checkpoint is expected to contain both model hyperparameters and weights.
If skip_readout is True, the readout layer weights are not loaded from the checkpoint.
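One way to picture the effect of skip_readout is as a filter over the checkpoint's weight dictionary. The sketch below uses a plain dictionary with string placeholders instead of tensors. It is an assumption-laden illustration, not the library's loading code: the helper name `filter_state_dict`, the key names, and the "readout." prefix are all hypothetical.

```python
def filter_state_dict(state_dict, skip_readout=False, readout_prefix="readout."):
    """Hypothetical helper: drop readout weights when skip_readout is True."""
    if not skip_readout:
        return dict(state_dict)
    # Assumption: readout layer parameters share a common key prefix.
    return {
        key: value
        for key, value in state_dict.items()
        if not key.startswith(readout_prefix)
    }


# Toy checkpoint contents (assumed key names, placeholder values).
checkpoint_weights = {
    "unit_emb.weight": "w0",
    "proc_layers.0.weight": "w1",
    "readout.weight": "w2",
    "readout.bias": "w3",
}

kept = filter_state_dict(checkpoint_weights, skip_readout=True)
```

With the readout entries filtered out, the freshly initialized readout layer built from readout_spec keeps its new weights, which is the behaviour one would want when fine-tuning on a task with a different output dimension.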