POYO
- class torch_brain.models.POYO(*, sequence_length, latent_step, num_latents_per_step=64, dim_out, dim=512, depth=2, dim_head=64, cross_heads=1, self_heads=8, ffn_dropout=0.2, lin_dropout=0.4, atn_dropout=0.0, emb_init_scale=0.02, t_min=0.0001, t_max=2.0627)
Bases: torch.nn.modules.module.Module
POYO model from Azabou et al. 2023, A Unified, Scalable Framework for Neural Population Decoding.
POYO is a transformer-based model for neural decoding from electrophysiological recordings.
Input tokens are constructed by combining unit embeddings, token type embeddings, and time embeddings for each spike in the sequence.
The input sequence is compressed using cross-attention, where learnable latent tokens (each with an associated timestamp) attend to the input tokens.
The compressed latent token representations undergo further refinement through multiple self-attention processing layers.
Query tokens are constructed for the desired outputs by combining session embeddings and output timestamps.
These query tokens attend to the processed latent representations through cross-attention, producing outputs in the model’s dimensional space (dim).
Finally, a task-specific linear layer maps the outputs from the model dimension to the appropriate output dimension.
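The encode, process, decode flow described above can be sketched with stock PyTorch layers. This is only an illustration under simplifying assumptions, not the torch_brain implementation: `TinyLatentDecoder` is a hypothetical class, the tokens are random tensors rather than sums of learned embeddings, and the timestamp (rotary) handling is omitted entirely.

```python
import torch
from torch import nn


class TinyLatentDecoder(nn.Module):
    """Illustrative encoder-processor-decoder (hypothetical, not torch_brain code)."""

    def __init__(self, dim=32, dim_out=2, depth=2, cross_heads=1, self_heads=4):
        super().__init__()
        # Encoder: latent tokens (queries) attend to input tokens (keys/values).
        self.enc_atn = nn.MultiheadAttention(dim, cross_heads, batch_first=True)
        # Processor: a stack of self-attention layers over the latent tokens.
        self.proc = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    dim, self_heads, dim_feedforward=4 * dim, batch_first=True
                )
                for _ in range(depth)
            ]
        )
        # Decoder: output query tokens attend to the processed latents.
        self.dec_atn = nn.MultiheadAttention(dim, cross_heads, batch_first=True)
        # Readout: map from the model dimension to the output dimension.
        self.readout = nn.Linear(dim, dim_out)

    def forward(self, inputs, latents, queries):
        # Compress the (long) input sequence into the (short) latent sequence.
        latents = latents + self.enc_atn(latents, inputs, inputs)[0]
        # Refine the latents with self-attention.
        for layer in self.proc:
            latents = layer(latents)
        # Read out one vector per output query.
        out = queries + self.dec_atn(queries, latents, latents)[0]
        return self.readout(out)


model = TinyLatentDecoder().eval()
inputs = torch.randn(3, 100, 32)   # (batch, n_spikes, dim) input tokens
latents = torch.randn(3, 16, 32)   # (batch, n_latents, dim) latent tokens
queries = torch.randn(3, 10, 32)   # (batch, n_out, dim) output query tokens
with torch.no_grad():
    pred = model(inputs, latents, queries)
print(pred.shape)  # torch.Size([3, 10, 2])
```

Note that the cost of the self-attention stack depends on the number of latents (16 here), not on the number of input spikes (100), which is the point of compressing the input first.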
- Parameters:
sequence_length (float) – Maximum duration of the input spike sequence (in seconds)
latent_step (float) – Timestep of the latent grid (in seconds)
num_latents_per_step (int) – Number of unique latent tokens (repeated at every latent step)
dim (int) – Hidden dimension of the model
dim_out (int) – Output dimension of the model
depth (int) – Number of processing layers (self-attentions in the latent space)
dim_head (int) – Dimension of each attention head
cross_heads (int) – Number of attention heads used in a cross-attention layer
self_heads (int) – Number of attention heads used in a self-attention layer
ffn_dropout (float) – Dropout rate for feed-forward networks
lin_dropout (float) – Dropout rate for linear layers
atn_dropout (float) – Dropout rate for attention
emb_init_scale (float) – Scale for embedding initialization
t_min (float) – Minimum timestamp resolution for rotary embeddings
t_max (float) – Maximum timestamp resolution for rotary embeddings
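To make the relationship between sequence_length, latent_step, and num_latents_per_step concrete, here is a minimal sketch of a latent grid. It assumes that the same num_latents_per_step latent indices are repeated at each of sequence_length / latent_step steps. `make_latent_grid` is a hypothetical helper, not a torch_brain function, and placing each timestamp at the center of its step is an assumption made for illustration.

```python
import torch


def make_latent_grid(sequence_length, latent_step, num_latents_per_step):
    """Hypothetical helper: build latent indices and timestamps on a regular grid."""
    # Number of latent steps that fit in the sequence window.
    num_steps = round(sequence_length / latent_step)
    # One timestamp per step, placed at the step center (an assumption).
    step_times = (torch.arange(num_steps, dtype=torch.float32) + 0.5) * latent_step
    # Every latent within a step shares that step's timestamp.
    latent_timestamps = step_times.repeat_interleave(num_latents_per_step)
    # The same set of latent indices is reused at every step.
    latent_index = torch.arange(num_latents_per_step).repeat(num_steps)
    return latent_index, latent_timestamps


latent_index, latent_timestamps = make_latent_grid(
    sequence_length=1.0, latent_step=0.125, num_latents_per_step=4
)
print(latent_index.shape)       # torch.Size([32]): 8 steps x 4 latents
print(latent_index[:5])         # tensor([0, 1, 2, 3, 0])
print(latent_timestamps[:5])    # first four latents share one timestamp
```

Under this reading, the total number of latent tokens is (sequence_length / latent_step) × num_latents_per_step, so halving latent_step doubles the latent sequence length.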
- forward(*, input_unit_index, input_timestamps, input_token_type, input_mask=None, latent_index, latent_timestamps, session_index, output_timestamps, output_mask=None, unpack_output=False)
Forward pass of the POYO model.
The model processes input spike sequences through its encoder-processor-decoder architecture to generate task-specific predictions.
- Parameters:
input_unit_index (Tensor) – Indices of input units
input_timestamps (Tensor) – Timestamps of input spikes
input_token_type (Tensor) – Type of input tokens
latent_index (Tensor) – Indices for latent tokens
latent_timestamps (Tensor) – Timestamps for latent tokens
session_index (Tensor) – Index of the recording session
output_timestamps (Tensor) – Timestamps for output predictions
output_mask (Optional[Tensor]) – A mask of the same size as output_timestamps. True means that the corresponding timestamp is a valid query for POYO. This is required if and only if unpack_output is set to True.
unpack_output (bool) – If False, this function returns a padded tensor of shape (batch size, max number of output queries in the batch, dim_out); in that case, apply output_mask externally to keep only the valid outputs. If True, it returns a list of Tensors whose length equals the batch size, where the i-th Tensor has shape (number of valid output queries for the i-th sample, dim_out).
- Return type:
- Returns:
A torch.Tensor of shape (batch, n_out, dim_out) containing the predicted outputs corresponding to output_timestamps.
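When unpack_output is False, the padded rows have to be filtered out with output_mask before computing a loss or a metric. The sketch below shows two ways of doing that in plain PyTorch. The `padded` tensor is a stand-in for the model output (it is not produced by POYO here), and the shapes are chosen only for illustration.

```python
import torch

batch, n_out, dim_out = 2, 4, 3

# Stand-in for the padded model output of shape (batch, n_out, dim_out).
padded = torch.arange(batch * n_out * dim_out, dtype=torch.float32).reshape(
    batch, n_out, dim_out
)

# True marks a valid output query; False marks padding.
output_mask = torch.tensor(
    [
        [True, True, True, False],   # sample 0 has 3 valid queries
        [True, True, False, False],  # sample 1 has 2 valid queries
    ]
)

# Option 1: flatten all valid outputs across the batch, e.g. for a loss.
valid = padded[output_mask]
print(valid.shape)  # torch.Size([5, 3])

# Option 2: one tensor per sample, mirroring the list form described
# for unpack_output=True.
unpacked = [padded[i][output_mask[i]] for i in range(batch)]
print([tuple(u.shape) for u in unpacked])  # [(3, 3), (2, 3)]
```

Boolean indexing keeps the valid rows in batch order, so concatenating the per-sample list gives the same tensor as the flat selection.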
- tokenize(data)
Tokenizer used to tokenize Data for the POYO model.
This tokenizer can be called as a transform. If you are applying multiple transforms, make sure to apply this one last.
This code runs on CPU. Do not access GPU tensors inside this function.
- Return type:
- init_vocabs(dataset)
Initializes the model’s unit_emb and session_emb vocabularies.
- Parameters:
dataset (Dataset) – A Dataset with SpikingDatasetMixin.