API Reference

Browse a specific module, or search the entire list of APIs.

All APIs

Object

Description

Dataset

PyTorch Dataset for loading time-slices of neural data recordings from HDF5 files.

DatasetIndex

Index for accessing a specific time interval of a recording within a Dataset.

NestedDataset

Dataset that composes multiple Dataset instances under a single interface.

SpikingDatasetMixin

Mixin class for torch_brain.dataset.Dataset subclasses containing spiking data.

CalciumImagingDatasetMixin

Mixin class for torch_brain.dataset.Dataset subclasses containing calcium imaging data.

MultiChannelDatasetMixin

Mixin class for torch_brain.dataset.Dataset subclasses containing multi-channel data.

RandomFixedWindowSampler

Samples fixed-length windows randomly from a collection of time intervals.

SequentialFixedWindowSampler

Samples fixed-length windows sequentially in a deterministic, reproducible order.

TrialSampler

Samples complete trial intervals without windowing.

DistributedEvaluationSamplerWrapper

Wraps any sampler for use in distributed evaluation without dropping samples.

DistributedStitchingFixedWindowSampler

Distributed sliding-window sampler that co-locates windows for prediction stitching.

Compose

Compose several transforms together. All transforms are called sequentially, in the order given.

RandomChoice

Apply a single transformation randomly picked from a list.

ConditionalChoice

Conditionally apply a single transformation based on whether a condition is met.

UnitDropout

Augmentation that randomly drops units from the sample. By default, the number of units to keep is sampled from a triangular distribution (see TriangleDistribution).

TriangleDistribution

Triangular distribution with a peak at mode_units, going from min_units to max_units.

RandomTimeScaling

<no summary>

RandomOutputSampler

<no summary>

RandomCrop

<no summary>

BinSpikes

Bin spike events into fixed-width time bins.

UnitFilter

Drop units based on the mask_fn given in the constructor.

UnitFilterById

Keep/drop units based on the keyword/regex given in the constructor.

chain

Wrap an object to specify that it (or any of its members) should be stacked along the first dimension when collated.

collate

Extension of PyTorch’s default_collate function to enable more advanced collation of samples of variable length.

pad

Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch.

pad8

Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch, rounded up to a multiple of 8.

pad2d

pad2d8

track_batch

Wrap an array or tensor to track the batch_index.

track_mask

Wrap an array or tensor to specify that its padding mask should be tracked.

track_mask8

Wrap an array or tensor to specify that its padding mask should be tracked.

track_mask2d

Wrap an array or tensor to specify that its padding mask should be tracked.

track_mask2d8

Wrap an array or tensor to specify that its padding mask should be tracked.

Embedding

A simple extension of torch.nn.Embedding to allow more control over weight initialization.

InfiniteVocabEmbedding

Embedding layer with a vocabulary that can be extended. The vocabulary is saved along with the model and restored when the state_dict is loaded.

RotaryTimeEmbedding

Rotary time/positional embedding layer, designed to be used with RotarySelfAttention and RotaryCrossAttention.

SinusoidalTimeEmbedding

Sinusoidal time/position embedding layer.

RotaryCrossAttention

Cross-attention layer with rotary positional embeddings.

RotarySelfAttention

Self-attention layer with rotary positional embeddings.

POYO

stitch

Pools values that share the same timestamp using mean or mode operations.

create_linspace_latent_tokens

Creates a sequence of latent tokens. Each token is defined by a latent index and a timestamp.

create_start_end_unit_tokens

Creates a start and an end token for each unit. Each token is defined by the unit index, the token type index, and a timestamp.

resolve_weights_based_on_interval_membership

Determine weights for timestamps based on which intervals they fall within.

isin_interval

Check if timestamps are in any of the intervals in the Interval object.

np_string_prefix

Adds a string prefix to each element of a numpy string array.

bin_spikes

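Bins spikes into time bins of size bin_size. If the total time spanned by the spikes is not a multiple of bin_size, the spikes are truncated to the nearest multiple of bin_size.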
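
To make the bin_spikes entry concrete, here is a minimal NumPy sketch of fixed-width spike binning. It is an illustration under stated assumptions, not torch_brain's implementation: the function name bin_spikes_sketch and its parameters (timestamps, unit_index, num_units, start, end) are hypothetical, and it assumes that a trailing remainder shorter than bin_size is simply dropped.

```python
import numpy as np


def bin_spikes_sketch(timestamps, unit_index, num_units, bin_size, start, end):
    """Count spikes per unit in fixed-width bins covering [start, end).

    Hypothetical helper for illustration only; not torch_brain's bin_spikes.
    Assumption: a trailing remainder shorter than bin_size is dropped.
    """
    num_bins = int(np.floor((end - start) / bin_size))
    counts = np.zeros((num_units, num_bins), dtype=np.int64)
    # Bin index of each spike relative to the window start.
    bin_index = np.floor((timestamps - start) / bin_size).astype(np.int64)
    # Keep only spikes that fall inside one of the complete bins.
    valid = (bin_index >= 0) & (bin_index < num_bins)
    # np.add.at accumulates correctly even when (unit, bin) pairs repeat.
    np.add.at(counts, (unit_index[valid], bin_index[valid]), 1)
    return counts


# Two units, 100 ms bins over a 350 ms window (the last 50 ms are dropped).
timestamps = np.array([0.01, 0.02, 0.15, 0.31, 0.05])
unit_index = np.array([0, 0, 1, 1, 1])
counts = bin_spikes_sketch(timestamps, unit_index, num_units=2, bin_size=0.1, start=0.0, end=0.35)
print(counts.tolist())  # [[2, 0, 0], [1, 1, 0]]
```

The sketch uses np.add.at rather than `counts[unit, bin] += 1` because fancy-index assignment counts a repeated (unit, bin) pair only once, which would undercount bursts of spikes from the same unit.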

Dataset