CalciumPOYOPlus
- class torch_brain.models.CalciumPOYOPlus(*, sequence_length, readout_specs={
      'arm_velocity_2d': ModalitySpec(id=3, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='behavior.timestamps', value_key='behavior.hand_vel', loss_fn=MSELoss()),
      'cursor_position_2d': ModalitySpec(id=2, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='cursor.timestamps', value_key='cursor.pos', loss_fn=MSELoss()),
      'cursor_velocity_2d': ModalitySpec(id=1, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='cursor.timestamps', value_key='cursor.vel', loss_fn=MSELoss()),
      'drifting_gratings_orientation': ModalitySpec(id=4, dim=8, type=<DataType.MULTINOMIAL: 2>, timestamp_key='drifting_gratings.timestamps', value_key='drifting_gratings.orientation_id', loss_fn=CrossEntropyLoss()),
      'drifting_gratings_temporal_frequency': ModalitySpec(id=5, dim=5, type=<DataType.MULTINOMIAL: 2>, timestamp_key='drifting_gratings.timestamps', value_key='drifting_gratings.temporal_frequency_id', loss_fn=CrossEntropyLoss()),
      'gabor_orientation': ModalitySpec(id=14, dim=4, type=<DataType.MULTINOMIAL: 2>, timestamp_key='gabors.timestamps', value_key='gabors.gabors_orientation', loss_fn=CrossEntropyLoss()),
      'gabor_pos_2d': ModalitySpec(id=15, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='gabors.timestamps', value_key='gabors.pos_2d', loss_fn=MSELoss()),
      'gaze_pos_2d': ModalitySpec(id=17, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='gaze.timestamps', value_key='gaze.pos_2d', loss_fn=MSELoss()),
      'locally_sparse_noise_frame': ModalitySpec(id=9, dim=8000, type=<DataType.MULTINOMIAL: 2>, timestamp_key='locally_sparse_noise.timestamps', value_key='locally_sparse_noise.frame', loss_fn=CrossEntropyLoss()),
      'natural_movie_one_frame': ModalitySpec(id=6, dim=900, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_movie_one.timestamps', value_key='natural_movie_one.frame', loss_fn=CrossEntropyLoss()),
      'natural_movie_three_frame': ModalitySpec(id=8, dim=3600, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_movie_three.timestamps', value_key='natural_movie_three.frame', loss_fn=CrossEntropyLoss()),
      'natural_movie_two_frame': ModalitySpec(id=7, dim=900, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_movie_two.timestamps', value_key='natural_movie_two.frame', loss_fn=CrossEntropyLoss()),
      'natural_scenes': ModalitySpec(id=13, dim=119, type=<DataType.MULTINOMIAL: 2>, timestamp_key='natural_scenes.timestamps', value_key='natural_scenes.frame', loss_fn=CrossEntropyLoss()),
      'pupil_location': ModalitySpec(id=18, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='pupil.timestamps', value_key='pupil.location', loss_fn=MSELoss()),
      'pupil_size_2d': ModalitySpec(id=19, dim=2, type=<DataType.CONTINUOUS: 0>, timestamp_key='pupil.timestamps', value_key='pupil.size_2d', loss_fn=MSELoss()),
      'running_speed': ModalitySpec(id=16, dim=1, type=<DataType.CONTINUOUS: 0>, timestamp_key='running.timestamps', value_key='running.running_speed', loss_fn=MSELoss()),
      'static_gratings_orientation': ModalitySpec(id=10, dim=6, type=<DataType.MULTINOMIAL: 2>, timestamp_key='static_gratings.timestamps', value_key='static_gratings.orientation_id', loss_fn=CrossEntropyLoss()),
      'static_gratings_phase': ModalitySpec(id=12, dim=5, type=<DataType.MULTINOMIAL: 2>, timestamp_key='static_gratings.timestamps', value_key='static_gratings.phase_id', loss_fn=CrossEntropyLoss()),
      'static_gratings_spatial_frequency': ModalitySpec(id=11, dim=5, type=<DataType.MULTINOMIAL: 2>, timestamp_key='static_gratings.timestamps', value_key='static_gratings.spatial_frequency_id', loss_fn=CrossEntropyLoss())
  }, latent_step, num_latents_per_step=64, dim=512, depth=2, dim_head=64, cross_heads=1, self_heads=8, ffn_dropout=0.2, lin_dropout=0.4, atn_dropout=0.0, emb_init_scale=0.02, t_min=0.0001, t_max=2.0627)
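Each entry of `readout_specs` pairs a task name with a spec giving its numeric `id`, output width `dim`, data `type`, the keys to read timestamps and targets from, and a loss. The plain-Python sketch below uses hypothetical stand-ins for `ModalitySpec` and `DataType` (not the library's classes) to show how one might restrict decoding to a subset of the defaults and how `type` and `dim` can be read. Interpreting `dim` as the number of classes for MULTINOMIAL readouts is an assumption consistent with the defaults (for example, `natural_scenes` has `dim=119`).

```python
from dataclasses import dataclass
from enum import Enum


class DataType(Enum):
    # Hypothetical stand-in enum; member values mirror those in the signature above.
    CONTINUOUS = 0
    MULTINOMIAL = 2


@dataclass(frozen=True)
class ModalitySpec:
    # Hypothetical stand-in for the library's ModalitySpec. The loss is
    # stored as a name string here instead of a torch loss module.
    id: int
    dim: int
    type: DataType
    timestamp_key: str
    value_key: str
    loss_fn: str


# Three entries copied from the default readout_specs.
DEFAULT_SPECS = {
    "running_speed": ModalitySpec(
        16, 1, DataType.CONTINUOUS,
        "running.timestamps", "running.running_speed", "MSELoss"),
    "gaze_pos_2d": ModalitySpec(
        17, 2, DataType.CONTINUOUS,
        "gaze.timestamps", "gaze.pos_2d", "MSELoss"),
    "natural_scenes": ModalitySpec(
        13, 119, DataType.MULTINOMIAL,
        "natural_scenes.timestamps", "natural_scenes.frame", "CrossEntropyLoss"),
}


def select_readouts(specs, names):
    """Return the subset of readout specs to decode, failing on unknown names."""
    missing = [name for name in names if name not in specs]
    if missing:
        raise KeyError(f"unknown readouts: {missing}")
    return {name: specs[name] for name in names}


def describe(spec):
    """Summarize what a readout head would predict for one spec."""
    if spec.type is DataType.CONTINUOUS:
        # Regression target: `dim` real values per output timestamp (MSE loss).
        return f"regress {spec.dim} value(s) from {spec.value_key}"
    # Classification target: `dim` logits per output timestamp (cross-entropy loss).
    return f"classify into {spec.dim} classes from {spec.value_key}"


selected = select_readouts(DEFAULT_SPECS, ["running_speed", "natural_scenes"])
for name, spec in selected.items():
    print(name, "->", describe(spec))
```

With the real library, the same idea would apply to a dictionary of its own `ModalitySpec` objects passed as `readout_specs`, so that only the selected tasks get readout heads.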
Bases: torch.nn.modules.module.Module
Calcium POYO+ model from Azabou et al. 2025, "Multi-session, multi-task neural decoding from distinct cell-types and brain regions".
Calcium POYO+ is a transformer-based model for neural decoding from calcium imaging recordings. It extends the POYO+ architecture with a calcium value map.
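The description does not spell out what the calcium value map computes. One plausible reading, stated here as an assumption rather than the library's implementation, is that each input token combines a learned per-unit embedding with a learned affine projection of that unit's scalar calcium value. The plain-Python sketch below illustrates that idea; all names are hypothetical, and the embedding init scale simply borrows the default of `emb_init_scale`.

```python
import random


def make_unit_embeddings(num_units, dim, init_scale=0.02, seed=0):
    """Random per-unit embedding table (init_scale borrows emb_init_scale's default)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, init_scale) for _ in range(dim)] for _ in range(num_units)]


def value_map(value, weight, bias):
    """Affine map from one scalar calcium value to a dim-sized vector."""
    return [w * value + b for w, b in zip(weight, bias)]


def embed_tokens(input_unit_index, input_values, unit_emb, weight, bias):
    """Build one input token per (unit, calcium value) pair.

    Each token is the unit's embedding plus the value-map projection of the
    calcium value recorded for that unit at that timestamp.
    """
    tokens = []
    for unit, value in zip(input_unit_index, input_values):
        mapped = value_map(value, weight, bias)
        tokens.append([e + m for e, m in zip(unit_emb[unit], mapped)])
    return tokens


# Toy example with dim=2 and hand-picked parameters.
tokens = embed_tokens(
    input_unit_index=[1, 0],
    input_values=[2.0, 3.0],
    unit_emb=[[0.0, 0.0], [1.0, 1.0]],
    weight=[1.0, 2.0],
    bias=[0.0, 0.5],
)
print(tokens)  # → [[3.0, 5.5], [3.0, 6.5]]
```

Under this reading, spiking POYO+ tokens identify a unit and an event time, while calcium tokens additionally need to carry a continuous amplitude, which is what the value map would inject.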
- forward(*, input_unit_index, input_timestamps, input_values, input_mask=None, latent_index, latent_timestamps, output_session_index, output_timestamps, output_decoder_index, unpack_output=False)
Forward pass of the Calcium POYO+ model.
The model processes input calcium sequences through its encoder-processor-decoder architecture to generate task-specific predictions.
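Task-specific predictions mean that every output query is sent to the readout head of one task. The sketch below assumes, without confirmation from this page, that each entry of `output_decoder_index` holds the `id` of a readout spec (for example 16 for `running_speed`). The heads are toy functions standing in for the model's learned readouts.

```python
from collections import defaultdict


def route_outputs(output_decoder_index, output_latents, readouts):
    """Apply each output query's task-specific head and group results by task.

    output_decoder_index: one readout id per output query.
    output_latents: one decoded vector per output query.
    readouts: maps readout id -> (task name, head function).
    """
    predictions = defaultdict(list)
    for decoder_id, latent in zip(output_decoder_index, output_latents):
        task, head = readouts[decoder_id]
        predictions[task].append(head(latent))
    return dict(predictions)


# Hypothetical toy heads keyed by ids from the default readout_specs.
readouts = {
    16: ("running_speed", lambda z: [sum(z)]),          # 1-dim regression head
    13: ("natural_scenes", lambda z: [2 * x for x in z]),  # toy classification logits
}
preds = route_outputs(
    output_decoder_index=[16, 13, 16],
    output_latents=[[1.0, 2.0], [0.5, 0.25], [3.0, 1.0]],
    readouts=readouts,
)
print(preds)
```

Grouping by task lets a single batch mix queries for several readouts while each task keeps its own output width and loss.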
- Parameters:
input_unit_index (Tensor) – Indices of input units
input_timestamps (Tensor) – Timestamps of the input calcium samples
input_values (Tensor) – Calcium values of the input sequence
input_mask (Tensor, optional) – Boolean mask marking valid (non-padded) input tokens
latent_index (Tensor) – Indices for latent tokens
latent_timestamps (Tensor) – Timestamps for latent tokens
output_session_index – Index of the recording session for each output
output_timestamps (Tensor) – Timestamps for output predictions
output_decoder_index (Tensor) – Indices indicating which decoder to use for each output
output_batch_index – Optional batch indices for outputs
output_values – Ground-truth values for supervised training
output_weights – Optional weights for loss computation
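The constructor's `latent_step` and `num_latents_per_step` suggest that `latent_index` and `latent_timestamps` describe a regular grid of latent tokens across the context window. The sketch below builds such a grid under stated assumptions (not confirmed by this page): `sequence_length` is in seconds and is a multiple of `latent_step`, each step carries `num_latents_per_step` latents indexed 0 to `num_latents_per_step - 1`, and all latents in a step share the step's center time.

```python
def make_latent_tokens(sequence_length, latent_step, num_latents_per_step):
    """Lay out latent tokens on a regular time grid over the context window.

    Every latent_step seconds there are num_latents_per_step latents; each
    latent in a step gets a distinct index in [0, num_latents_per_step), and
    all of them share the step's center time.
    """
    ratio = sequence_length / latent_step
    num_steps = round(ratio)
    if abs(ratio - num_steps) > 1e-9:
        raise ValueError("sequence_length should be a multiple of latent_step")
    latent_index, latent_timestamps = [], []
    for step in range(num_steps):
        center = (step + 0.5) * latent_step
        for k in range(num_latents_per_step):
            latent_index.append(k)
            latent_timestamps.append(center)
    return latent_index, latent_timestamps


idx, ts = make_latent_tokens(sequence_length=1.0, latent_step=0.25, num_latents_per_step=2)
print(idx)  # → [0, 1, 0, 1, 0, 1, 0, 1]
print(ts)   # → [0.125, 0.125, 0.375, 0.375, 0.625, 0.625, 0.875, 0.875]
```

With the default `num_latents_per_step=64`, a 1 s window and a hypothetical `latent_step` of 0.125 s would give 8 × 64 = 512 latent tokens.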
- Returns:
  A tuple containing
  - A list of dictionaries, each containing the predicted outputs for a given task
  - Total loss value
  - Dictionary of per-task losses
- Return type:
  Tuple
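The return values include a total loss and per-task losses, and the parameters mention ground-truth values and optional per-output weights. The sketch below shows one plausible way to compute them, assuming MSE for CONTINUOUS readouts, cross-entropy for MULTINOMIAL readouts, a weighted mean within each task, and an unweighted sum across tasks. The reduction and task weighting used by torch_brain may differ; all names here are hypothetical.

```python
import math


def mse(pred, target):
    """Mean squared error over the components of one prediction."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def cross_entropy(logits, label):
    """Negative log-softmax probability of the true class (numerically stable)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]


def compute_losses(predictions, targets, loss_types, weights=None):
    """Return (total_loss, per_task_losses).

    predictions, targets: task -> list of per-output predictions / targets.
    loss_types: task -> "mse" or "ce".
    weights: optional task -> list of per-output weights (cf. output_weights).
    """
    per_task = {}
    for task, preds in predictions.items():
        tgts = targets[task]
        ws = (weights or {}).get(task, [1.0] * len(preds))
        fn = mse if loss_types[task] == "mse" else cross_entropy
        losses = [fn(p, t) for p, t in zip(preds, tgts)]
        # Weighted mean over this task's outputs.
        per_task[task] = sum(w * loss for w, loss in zip(ws, losses)) / sum(ws)
    # Assumption: tasks contribute equally to the total.
    total = sum(per_task.values())
    return total, per_task


total, per_task = compute_losses(
    predictions={"running_speed": [[1.0], [3.0]], "natural_scenes": [[0.0, 0.0]]},
    targets={"running_speed": [[0.0], [3.0]], "natural_scenes": [0]},
    loss_types={"running_speed": "mse", "natural_scenes": "ce"},
)
print(per_task)
print(total)
```

Here `running_speed` averages errors of 1.0 and 0.0 to 0.5, and the two-class uniform logits give a cross-entropy of ln 2 for `natural_scenes`.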