CalciumPOYOPlus

class torch_brain.models.CalciumPOYOPlus(*, sequence_length, readout_specs={'arm_velocity_2d': ModalitySpec(id=3, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='behavior.timestamps', value_key='behavior.hand_vel', loss_fn=MSELoss()), 'cursor_position_2d': ModalitySpec(id=2, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='cursor.timestamps', value_key='cursor.pos', loss_fn=MSELoss()), 'cursor_velocity_2d': ModalitySpec(id=1, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='cursor.timestamps', value_key='cursor.vel', loss_fn=MSELoss()), 'drifting_gratings_orientation': ModalitySpec(id=4, dim=8, type=<DataType.MULTINOMIAL: 2>, timestamp_key='drifting_gratings.timestamps', value_key='drifting_gratings.orientation_id', loss_fn=CrossEntropyLoss()), 'drifting_gratings_temporal_frequency': ModalitySpec(id=5, dim=5, type=<DataType.MULTINOMIAL: 2>, timestamp_key='drifting_gratings.timestamps', value_key='drifting_gratings.temporal_frequency_id', loss_fn=CrossEntropyLoss()), 'gabor_orientation': ModalitySpec(id=14, dim=4, type=<DataType.MULTINOMIAL: 2>, timestamp_key='gabors.timestamps', value_key='gabors.gabors_orientation', loss_fn=CrossEntropyLoss()), 'gabor_pos_2d': ModalitySpec(id=15, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='gabors.timestamps', value_key='gabors.pos_2d', loss_fn=MSELoss()), 'gaze_pos_2d': ModalitySpec(id=17, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='gaze.timestamps', value_key='gaze.pos_2d', loss_fn=MSELoss()), 'locally_sparse_noise_frame': ModalitySpec(id=9, dim=8000, type=<DataType.MULTINOMIAL: 2>, timestamp_key='locally_sparse_noise.timestamps', value_key='locally_sparse_noise.frame', loss_fn=CrossEntropyLoss()), 'natural_movie_one_frame': ModalitySpec(id=6, dim=900, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_movie_one.timestamps', value_key='natural_movie_one.frame', loss_fn=CrossEntropyLoss()), 'natural_movie_three_frame': ModalitySpec(id=8, dim=3600, type=<DataType.MULTINOMIAL: 2>, 
timestamp_key='natural_movie_three.timestamps', value_key='natural_movie_three.frame', loss_fn=CrossEntropyLoss()), 'natural_movie_two_frame': ModalitySpec(id=7, dim=900, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_movie_two.timestamps', value_key='natural_movie_two.frame', loss_fn=CrossEntropyLoss()), 'natural_scenes': ModalitySpec(id=13, dim=119, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_scenes.timestamps', value_key='natural_scenes.frame', loss_fn=CrossEntropyLoss()), 'pupil_location': ModalitySpec(id=18, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='pupil.timestamps', value_key='pupil.location', loss_fn=MSELoss()), 'pupil_size_2d': ModalitySpec(id=19, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='pupil.timestamps', value_key='pupil.size_2d', loss_fn=MSELoss()), 'running_speed': ModalitySpec(id=16, dim=1, type=<DataType.CONTINUOUS: 0>, timestamp_key='running.timestamps', value_key='running.running_speed', loss_fn=MSELoss()), 'static_gratings_orientation': ModalitySpec(id=10, dim=6, type=<DataType.MULTINOMIAL: 2>, timestamp_key='static_gratings.timestamps', value_key='static_gratings.orientation_id', loss_fn=CrossEntropyLoss()), 'static_gratings_phase': ModalitySpec(id=12, dim=5, type=<DataType.MULTINOMIAL: 2>, timestamp_key='static_gratings.timestamps', value_key='static_gratings.phase_id', loss_fn=CrossEntropyLoss()), 'static_gratings_spatial_frequency': ModalitySpec(id=11, dim=5, type=<DataType.MULTINOMIAL: 2>, timestamp_key='static_gratings.timestamps', value_key='static_gratings.spatial_frequency_id', loss_fn=CrossEntropyLoss())}, latent_step, num_latents_per_step=64, dim=512, depth=2, dim_head=64, cross_heads=1, self_heads=8, ffn_dropout=0.2, lin_dropout=0.4, atn_dropout=0.0, emb_init_scale=0.02, t_min=0.0001, t_max=2.0627)

Bases: torch.nn.modules.module.Module

Calcium POYO+ model from Azabou et al. 2025, Multi-session, multi-task neural decoding from distinct cell-types and brain regions.

Calcium POYO+ is a transformer-based model for neural decoding from calcium imaging recordings. It extends the POYO+ architecture with a calcium value map.
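
The readout_specs default in the signature maps task names to ModalitySpec entries. As a rough illustration of that structure, the sketch below uses stand-in DataType and ModalitySpec classes (hypothetical standard-library re-creations, not the torch_brain classes; loss_fn is stored as a plain string here) and a hypothetical helper, select_readouts, to pick a subset of readouts by name. Field names and values are copied from the signature above.

```python
from dataclasses import dataclass
from enum import Enum


class DataType(Enum):
    # Stand-in enum; only the two members visible in the signature are reproduced.
    CONTINUOUS = 0
    MULTINOMIAL = 2


@dataclass(frozen=True)
class ModalitySpec:
    # Stand-in for the real spec class; loss_fn is a string label here,
    # whereas the signature shows loss objects such as MSELoss().
    id: int
    dim: int
    type: DataType
    timestamp_key: str
    value_key: str
    loss_fn: str


# Two entries copied from the default readout_specs shown in the signature.
ALL_SPECS = {
    "running_speed": ModalitySpec(
        id=16, dim=1, type=DataType.CONTINUOUS,
        timestamp_key="running.timestamps",
        value_key="running.running_speed",
        loss_fn="MSELoss",
    ),
    "drifting_gratings_orientation": ModalitySpec(
        id=4, dim=8, type=DataType.MULTINOMIAL,
        timestamp_key="drifting_gratings.timestamps",
        value_key="drifting_gratings.orientation_id",
        loss_fn="CrossEntropyLoss",
    ),
}


def select_readouts(specs, names):
    """Return the sub-dictionary of specs for the requested task names."""
    missing = [n for n in names if n not in specs]
    if missing:
        raise KeyError(f"unknown readout(s): {missing}")
    return {n: specs[n] for n in names}


subset = select_readouts(ALL_SPECS, ["running_speed"])
```

A dictionary shaped like subset, built from real ModalitySpec objects, has the form that the readout_specs argument takes in the signature.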

forward(*, input_unit_index, input_timestamps, input_values, input_mask=None, latent_index, latent_timestamps, output_session_index, output_timestamps, output_decoder_index, unpack_output=False)

Forward pass of the Calcium POYO+ model.

The model processes input calcium sequences through its encoder-processor-decoder architecture to generate task-specific predictions.

Parameters:
  • input_unit_index (Tensor) – Indices of input units

  • input_timestamps (Tensor) – Timestamps of the input calcium samples

  • input_values (Tensor) – Calcium values of input sequence

  • input_mask (Optional[Tensor]) – Mask for input sequence

  • latent_index (Tensor) – Indices for latent tokens

  • latent_timestamps (Tensor) – Timestamps for latent tokens

  • output_session_index (Tensor) – Index of the recording session

  • output_timestamps (Tensor) – Timestamps for output predictions

  • output_decoder_index (Tensor) – Indices indicating which decoder to use

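  • output_batch_index – Optional batch indices for outputs (not among the keyword arguments in the signature above)

  • output_values – Ground truth values for supervised training (not among the keyword arguments in the signature above)

  • output_weights – Optional weights for loss computation (not among the keyword arguments in the signature above)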
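
The latent_index and latent_timestamps arguments pair with the latent_step and num_latents_per_step constructor parameters. One plausible layout, assumed here rather than taken from the library, places num_latents_per_step latent tokens at the midpoint of each latent_step-wide window across the sequence. The helper name make_latent_tokens is hypothetical, and plain lists stand in for tensors.

```python
import math


def make_latent_tokens(sequence_length, latent_step, num_latents_per_step):
    """Build (latent_index, latent_timestamps) lists for one sequence.

    Assumed layout: each window of width latent_step gets
    num_latents_per_step tokens, all stamped at the window midpoint.
    """
    # Small epsilon guards against float round-off in the division.
    num_steps = math.ceil(sequence_length / latent_step - 1e-9)
    latent_index = []
    latent_timestamps = []
    for step in range(num_steps):
        midpoint = step * latent_step + latent_step / 2
        for latent_id in range(num_latents_per_step):
            latent_index.append(latent_id)
            latent_timestamps.append(midpoint)
    return latent_index, latent_timestamps


index, timestamps = make_latent_tokens(
    sequence_length=1.0, latent_step=0.5, num_latents_per_step=2
)
print(index)       # [0, 1, 0, 1]
print(timestamps)  # [0.25, 0.25, 0.75, 0.75]
```

Under this assumption the number of latent tokens grows with sequence_length / latent_step, which is why the two constructor arguments are best chosen together.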

Returns:

  Tuple containing:

  • A list of dictionaries, each containing the predicted outputs for a given task

  • Total loss value

  • Dictionary of per-task losses

Return type:

Tuple
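
To make the role of output_decoder_index concrete, the sketch below groups output positions by the task whose id matches each decoder index, using ids from the readout_specs default (cursor_velocity_2d has id=1, running_speed has id=16). The function group_by_task and the plain-list inputs are illustrative assumptions; the model itself operates on tensors, and its internal routing may differ.

```python
from collections import defaultdict

# Task ids copied from the default readout_specs in the class signature.
TASK_IDS = {"cursor_velocity_2d": 1, "running_speed": 16}


def group_by_task(output_decoder_index, predictions, task_ids):
    """Collect predictions per task name according to each decoder index."""
    id_to_name = {task_id: name for name, task_id in task_ids.items()}
    grouped = defaultdict(list)
    for decoder_id, prediction in zip(output_decoder_index, predictions):
        grouped[id_to_name[decoder_id]].append(prediction)
    return dict(grouped)


decoder_index = [1, 16, 1]
predictions = [[0.1, 0.2], [3.0], [0.3, 0.4]]
per_task = group_by_task(decoder_index, predictions, TASK_IDS)
```

Each output timestamp carries exactly one decoder index in this sketch, so a single sequence can mix queries for several tasks.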

tokenize(data)

Tokenizer used to tokenize Data for the Calcium POYO+ model.

This tokenizer can be called as a transform. If you are applying multiple transforms, make sure to apply this one last.

This code runs on CPU. Do not access GPU tensors inside this function.

Return type:

Dict
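
The advice to apply the tokenizer last can be illustrated with a tiny transform chain. The Compose class, the scale_values transform, and fake_tokenize are hypothetical stand-ins written for this sketch, and a plain dictionary stands in for the Data object. They only show the ordering: earlier transforms see the raw record, while the tokenizer's output is a dictionary of model inputs that later data-level transforms could no longer interpret.

```python
class Compose:
    """Apply callables in order, passing each result to the next."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data


def scale_values(data):
    # Example data-level transform: works on the raw record.
    return {**data, "values": [2.0 * v for v in data["values"]]}


def fake_tokenize(data):
    # Stand-in for model.tokenize: returns a dict of model inputs.
    return {
        "input_values": data["values"],
        "input_timestamps": data["timestamps"],
    }


pipeline = Compose([scale_values, fake_tokenize])  # tokenizer goes last
sample = {"values": [0.5, 1.5], "timestamps": [0.0, 0.1]}
tokens = pipeline(sample)
```

Reversing the order would hand scale_values a dictionary without a "values" key, which is the kind of failure the ordering rule avoids.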