API Reference

Browse a specific module, or search the entire list of APIs.

All APIs

Object

Description

Dataset

PyTorch Dataset for loading time-slices of neural data recordings from HDF5 files.

DatasetIndex

Index for accessing a specific time interval of a recording within a Dataset.

NestedDataset

Dataset that composes multiple Dataset instances under a single interface.

SpikingDatasetMixin

Mixin class for torch_brain.dataset.Dataset subclasses containing spiking data.

CalciumImagingDatasetMixin

Mixin class for torch_brain.dataset.Dataset subclasses containing calcium imaging data.

MultiChannelDatasetMixin

Mixin class for torch_brain.dataset.Dataset subclasses containing multi-channel data.

chain

Wrap an object to specify that it (or any of its members) should be stacked along the first dimension when batching.

collate

Extension of PyTorch’s default_collate function to enable more advanced collation of variable-length samples.

pad

Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch.

pad8

Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch, rounded up to a multiple of 8.

pad2d

<no summary>

pad2d8

<no summary>

track_batch

Wrap an array or tensor to track the batch_index.

track_mask

Wrap an array or tensor to specify that its padding mask should be tracked.

track_mask8

Wrap an array or tensor to specify that its padding mask should be tracked (pad8 variant).

track_mask2d

Wrap an array or tensor to specify that its padding mask should be tracked (pad2d variant).

track_mask2d8

Wrap an array or tensor to specify that its padding mask should be tracked (pad2d8 variant).

Dataset

RandomFixedWindowSampler

Samples fixed-length windows randomly from the intervals it is given.

SequentialFixedWindowSampler

Samples fixed-length windows sequentially, always in the same order.

TrialSampler

Randomly samples a single trial interval from the given intervals.

DistributedEvaluationSamplerWrapper

Wraps a sampler for use in a distributed evaluation setting.

DistributedStitchingFixedWindowSampler

A sampler designed specifically for evaluation that enables sliding-window evaluation with stitching in a distributed setting.

Compose

Compose several transforms together. All transforms are called sequentially, in order, each receiving the output of the previous one.

RandomChoice

Apply a single transformation randomly picked from a list.

ConditionalChoice

Conditionally apply a single transformation based on whether a condition is met.

UnitDropout

Augmentation that randomly drops units from the sample. By default, the number of units to keep is sampled from a triangular distribution.

TriangleDistribution

Triangular distribution with a peak at mode_units, going from min_units to max_units.

RandomTimeScaling

<no summary>

RandomOutputSampler

<no summary>

RandomCrop

<no summary>

BinSpikes

Bin spike events into fixed-width time bins.

UnitFilter

Drop units based on the mask_fn given in the constructor.

UnitFilterById

Keep/drop units based on the keyword/regex given in the constructor.

Embedding

A simple extension of torch.nn.Embedding to allow more control over weight initialization.

InfiniteVocabEmbedding

Embedding layer with a vocabulary that can be extended. The vocabulary is saved along with the model's state_dict.

RotaryTimeEmbedding

Rotary time/positional embedding layer. This module is designed to be used with RotarySelfAttention and RotaryCrossAttention.

SinusoidalTimeEmbedding

Sinusoidal time/position embedding layer.

FeedForward

A feed-forward network with GEGLU activation.

RotaryCrossAttention

Cross-attention layer with rotary positional embeddings.

RotarySelfAttention

Self-attention layer with rotary positional embeddings.

MultitaskReadout

A module that performs multi-task linear readouts from output embeddings.

prepare_for_multitask_readout

<no summary>

Loss

Base class for losses. All losses should support an optional weights argument.

MSELoss

Mean squared error loss. Supports an optional weights argument.

CrossEntropyLoss

Cross-entropy loss. Supports an optional weights argument.

MallowDistanceLoss

Loss based on the Mallows distance between predicted and target distributions. Supports an optional weights argument.

POYO

POYOPlus

CalciumPOYOPlus

poyo_mp

<no summary>

DataType

Enum defining the possible data types.

ModalitySpec

Specification for a modality.

register_modality

Register a new modality specification in the global registry.

get_modality_by_id

Get a modality specification by its ID.

MODALITY_REGISTRY

Global dictionary holding the registered modality specifications (ModalitySpec).

stitch

Pools values that share the same timestamp using mean or mode operations.

seed_everything

Sets the random seed for reproducibility.

create_linspace_latent_tokens

Creates a sequence of evenly spaced latent tokens.

create_start_end_unit_tokens

Creates a start token and an end token for each unit.

resolve_weights_based_on_interval_membership

Determine weights for timestamps based on which intervals they fall within.

isin_interval

Check if timestamps are in any of the intervals in the Interval object.

prepare_for_readout

<no summary>

np_string_prefix

Adds a string prefix to each element of a numpy string array.

bin_spikes

Bins spikes into time bins of size bin_size.
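bin_spikes turns spike event times into per-unit counts. The following is a minimal NumPy sketch of that idea, not torch_brain's implementation. It assumes spikes arrive as parallel arrays of timestamps and unit indices, and that spikes past the last complete bin are discarded. The function name `bin_spike_times` and its parameters are hypothetical.

```python
import numpy as np


def bin_spike_times(timestamps, unit_index, num_units, bin_size, start, end):
    """Hypothetical sketch (not the torch_brain signature): count spikes per unit
    in fixed-width bins.

    Returns an array of shape (num_units, num_bins). Only complete bins inside
    [start, end) are kept; spikes past the last complete bin are dropped.
    """
    timestamps = np.asarray(timestamps, dtype=np.float64)
    unit_index = np.asarray(unit_index, dtype=np.int64)

    # Number of complete bins; the small tolerance guards against float error
    # such as (0.3 - 0.0) / 0.1 == 2.9999999999999996.
    num_bins = int(np.floor((end - start) / bin_size + 1e-9))

    # Bin index of every spike.
    bin_index = np.floor((timestamps - start) / bin_size).astype(np.int64)

    # Keep only spikes that land inside a complete bin.
    keep = (bin_index >= 0) & (bin_index < num_bins)

    counts = np.zeros((num_units, num_bins), dtype=np.int64)
    # np.add.at accumulates correctly even when (unit, bin) pairs repeat.
    np.add.at(counts, (unit_index[keep], bin_index[keep]), 1)
    return counts
```

For example, with four spikes from two units, 0.5 s bins and a 2 s window, `bin_spike_times([0.1, 0.2, 0.6, 1.7], [0, 0, 1, 1], num_units=2, bin_size=0.5, start=0.0, end=2.0)` gives `[[2, 0, 0, 0], [0, 1, 0, 1]]`.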
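The pad/pad8 and track_mask/track_mask8 entries describe two related steps. Variable-length samples are padded to a common length, and a mask records which positions hold real data. The NumPy sketch below shows the idea outside torch_brain. `pad_with_mask` and its `multiple` parameter are hypothetical, and rounding the padded length up to a multiple of 8 is an assumption about what the "8" variants do.

```python
import numpy as np


def pad_with_mask(arrays, multiple=1, pad_value=0):
    """Hypothetical sketch: pad 1-D arrays to a shared length and return a mask.

    The target length is the longest array, rounded up to `multiple`
    (multiple=8 mimics the assumed pad8-style behavior).
    """
    max_len = max(len(a) for a in arrays)
    # Round up to the next multiple (no-op when multiple == 1).
    target = -(-max_len // multiple) * multiple

    dtype = np.result_type(*[np.asarray(a) for a in arrays])
    padded = np.full((len(arrays), target), pad_value, dtype=dtype)
    mask = np.zeros((len(arrays), target), dtype=bool)
    for i, a in enumerate(arrays):
        padded[i, : len(a)] = a
        mask[i, : len(a)] = True  # True marks real (non-padding) entries
    return padded, mask
```

Rounding lengths up to a multiple of 8 is a common trick for hardware-friendly tensor shapes. It costs a few extra padding entries, and the mask lets downstream layers ignore them.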
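chain and track_batch describe the alternative to padding. Variable-length samples are joined into one long array, and a record is kept of which sample each element came from. The NumPy sketch below assumes that chaining means concatenation along the first dimension. `chain_with_batch_index` is a hypothetical helper, not part of torch_brain.

```python
import numpy as np


def chain_with_batch_index(arrays):
    """Hypothetical sketch: concatenate samples along the first dimension and
    record which sample each row came from."""
    chained = np.concatenate([np.asarray(a) for a in arrays], axis=0)
    # Sample i contributes len(arrays[i]) copies of the index i.
    batch_index = np.repeat(np.arange(len(arrays)), [len(a) for a in arrays])
    return chained, batch_index
```

Compared with padding, this wastes no memory on filler entries. The cost is that per-sample boundaries must be recovered from `batch_index`.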
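RandomFixedWindowSampler and SequentialFixedWindowSampler both cut sampling intervals into fixed-length windows. The plain-Python sketch below illustrates the two strategies. It assumes intervals are `(start, end)` pairs in seconds and that each window must lie entirely inside its interval. The function names are hypothetical, and the sketch does not reproduce either sampler's constructor or exact behavior.

```python
import random


def sequential_windows(intervals, window_length, step=None):
    """Hypothetical sketch: tile each interval with windows, same order every time."""
    step = window_length if step is None else step
    windows = []
    for start, end in intervals:
        t = start
        # Only emit windows that fit completely inside the interval.
        while t + window_length <= end:
            windows.append((t, t + window_length))
            t += step
    return windows


def random_windows(intervals, window_length, seed=None):
    """Hypothetical sketch: tile each interval from a random offset, then shuffle."""
    rng = random.Random(seed)
    windows = []
    for start, end in intervals:
        if end - start < window_length:
            continue  # interval too short for a single window
        # A random offset in [0, window_length) shifts where windows fall from
        # one epoch to the next; clamp it so at least one window still fits.
        offset = min(rng.uniform(0, window_length), end - start - window_length)
        windows.extend(sequential_windows([(start + offset, end)], window_length))
    rng.shuffle(windows)
    return windows
```

The sequential variant suits evaluation, where a reproducible order matters. The randomized variant gives training a different set of window boundaries on each pass.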
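UnitDropout and TriangleDistribution together amount to two steps: draw how many units to keep, then keep a random subset of that size. The NumPy sketch below assumes the count comes from a triangular distribution over min_units, mode_units and max_units, as in the TriangleDistribution entry. The helper names are hypothetical, and the clamping rules for small samples are this sketch's own choice.

```python
import numpy as np


def sample_num_units(num_units, min_units, mode_units, max_units, rng):
    """Hypothetical sketch: draw how many units to keep from a triangular
    distribution on [min_units, max_units] peaking at mode_units."""
    hi = min(max_units, num_units)          # cannot keep more units than exist
    lo = min(min_units, hi)
    mode = min(max(mode_units, lo), hi)     # keep the mode inside [lo, hi]
    if lo == hi:
        return hi                           # degenerate range, nothing to sample
    return int(round(rng.triangular(lo, mode, hi)))


def unit_dropout(unit_ids, min_units, mode_units, max_units, rng):
    """Hypothetical sketch: keep a random subset of units, dropping the rest."""
    unit_ids = np.asarray(unit_ids)
    k = sample_num_units(len(unit_ids), min_units, mode_units, max_units, rng)
    keep = rng.choice(len(unit_ids), size=k, replace=False)
    return unit_ids[np.sort(keep)]  # preserve the original unit order
```

For example, `unit_dropout(np.arange(100), 20, 50, 80, np.random.default_rng(0))` keeps between 20 and 80 distinct units, most often around 50.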
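RotaryTimeEmbedding is a rotary embedding keyed on time, designed to work with the rotary attention layers. The key property of rotary embeddings is that, after rotation, a query-key dot product depends only on the difference between their times. The NumPy sketch below shows that property. It assumes the standard RoPE frequency schedule with base 10000 applied to continuous timestamps, and torch_brain's actual frequency parameters may differ. The function name is hypothetical.

```python
import numpy as np


def rotary_time_embed(x, timestamps, base=10000.0):
    """Hypothetical sketch of a rotary embedding driven by continuous timestamps.

    x: array of shape (n, d) with d even; timestamps: shape (n,).
    Each pair of dimensions (2i, 2i+1) is rotated by the angle t * freq_i.
    """
    n, d = x.shape
    # Geometric frequency schedule as in standard RoPE (assumed here).
    freqs = base ** (-np.arange(0, d, 2) / d)          # shape (d/2,)
    angles = np.asarray(timestamps)[:, None] * freqs    # shape (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)

    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x, dtype=np.float64)
    # 2-D rotation of each (even, odd) pair.
    out[:, 0::2] = x_even * cos - x_odd * sin
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out
```

Because each pair is rotated by an angle proportional to its timestamp, rotating a query at t = 1 and a key at t = 3 gives the same dot product as rotating them at t = 11 and t = 13. Attention scores therefore depend on time differences between tokens, not on absolute timestamps.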
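stitch pools values that share a timestamp using a mean or a mode. This is needed whenever overlapping evaluation windows produce several predictions for the same time point. The NumPy sketch below shows that pooling step. The function name and signature are hypothetical, and the "mode" branch assumes non-negative integer labels.

```python
import numpy as np


def pool_by_timestamp(timestamps, values, mode="mean"):
    """Hypothetical sketch: pool values that share the same timestamp.

    mode="mean" averages (e.g. regression outputs); mode="mode" takes the most
    frequent non-negative integer label (e.g. class predictions), with ties
    going to the smallest label.
    """
    timestamps = np.asarray(timestamps)
    values = np.asarray(values)
    unique_ts, inverse = np.unique(timestamps, return_inverse=True)
    inverse = inverse.reshape(-1)  # guard against shape differences across NumPy versions

    if mode == "mean":
        sums = np.zeros(len(unique_ts), dtype=np.float64)
        np.add.at(sums, inverse, values)
        counts = np.bincount(inverse, minlength=len(unique_ts))
        return unique_ts, sums / counts
    if mode == "mode":
        pooled = np.array(
            [np.bincount(values[inverse == g]).argmax() for g in range(len(unique_ts))]
        )
        return unique_ts, pooled
    raise ValueError(f"unknown mode: {mode!r}")
```

Mean pooling suits continuous outputs such as decoded velocities. Mode pooling suits class predictions, where averaging labels would be meaningless.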
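isin_interval checks whether timestamps fall inside any of a set of intervals. The NumPy sketch below uses broadcasting and assumes half-open `[start, end)` intervals. torch_brain's Interval object and its boundary convention are not reproduced here, and `in_any_interval` is a hypothetical name.

```python
import numpy as np


def in_any_interval(timestamps, starts, ends):
    """Hypothetical sketch: True where a timestamp lies in any [start, end) interval."""
    t = np.asarray(timestamps)[:, None]   # shape (n, 1) to broadcast against intervals
    starts = np.asarray(starts)[None, :]  # shape (1, m)
    ends = np.asarray(ends)[None, :]      # shape (1, m)
    return ((t >= starts) & (t < ends)).any(axis=1)
```

This builds an n × m comparison matrix. With many sorted, non-overlapping intervals, `np.searchsorted` on the start times avoids that memory cost.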
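The Loss entry states that every loss should accept an optional weights argument. For mean squared error, one common way to do this is to take a weighted average of per-sample errors, as in the NumPy sketch below. The function name is hypothetical. Normalizing by the sum of the weights is an assumption, not torch_brain's documented behavior.

```python
import numpy as np


def weighted_mse(pred, target, weights=None):
    """Hypothetical sketch of an MSE loss with optional per-sample weights."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    # Per-sample squared error, averaged over any trailing feature dimensions.
    err = ((pred - target) ** 2).reshape(len(pred), -1).mean(axis=1)
    if weights is None:
        return err.mean()
    weights = np.asarray(weights, dtype=np.float64)
    # Weighted average, so scaling all weights by a constant leaves the loss unchanged.
    return (weights * err).sum() / weights.sum()
```

Setting a sample's weight to zero removes it from the loss entirely. This is one way to ignore time points outside the intervals of interest.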