torch_brain.transforms

Compose

Compose several transforms together.

RandomChoice

Apply a single transformation randomly picked from a list.

ConditionalChoice

Conditionally apply a single transformation based on whether a condition is met.

UnitDropout

Randomly drop units from data.units and data.spikes.

RandomTimeScaling

Randomly scales the time axis.

RandomOutputSampler

Randomly drops output samples.

class Compose(transforms)

Bases: object

Compose several transforms together. The transforms are called sequentially, in the given order. Each must accept and return a single temporaldata.Data object, except the last transform, which may return any object.

Parameters:

transforms (list of callable) – list of transforms to compose.
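The composition rule can be sketched as follows. This is a minimal illustration, not torch_brain's implementation: `ComposeSketch` is a hypothetical name, and a plain dict stands in for a temporaldata.Data object.

```python
from typing import Any, Callable, Sequence


class ComposeSketch:
    """Hypothetical stand-in illustrating the documented Compose behavior."""

    def __init__(self, transforms: Sequence[Callable[[Any], Any]]) -> None:
        self.transforms = list(transforms)

    def __call__(self, data: Any) -> Any:
        # Apply the transforms in order; each receives the previous output.
        for transform in self.transforms:
            data = transform(data)
        return data


def add_one(d: dict) -> dict:
    return {**d, "x": d["x"] + 1}


def double(d: dict) -> dict:
    return {**d, "x": d["x"] * 2}


def extract_x(d: dict) -> int:
    # The last transform may return something other than the data object.
    return d["x"]


pipeline = ComposeSketch([add_one, double, extract_x])
print(pipeline({"x": 3}))  # prints 8
```

Order matters: with double placed before add_one, the same input would give 7 instead of 8.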

class RandomChoice(transforms, p=None)

Bases: object

Apply a single transformation randomly picked from a list.

Parameters:
  • transforms (List[Callable]) – list of transformations

  • p (list of floats, optional) – probability of each transform being picked. If p doesn’t sum to 1, it is automatically normalized. By default, all transforms have the same probability.
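The selection rule, including the normalization of p, can be sketched with the standard library's random.choices. `RandomChoiceSketch` is a hypothetical stand-in, not the library class, and its optional `rng` argument is an addition made only for reproducibility.

```python
import random
from typing import Any, Callable, Optional, Sequence


class RandomChoiceSketch:
    """Hypothetical stand-in illustrating the documented RandomChoice behavior."""

    def __init__(
        self,
        transforms: Sequence[Callable[[Any], Any]],
        p: Optional[Sequence[float]] = None,
        rng: Optional[random.Random] = None,
    ) -> None:
        self.transforms = list(transforms)
        if p is None:
            # Default: every transform is equally likely.
            p = [1.0] * len(self.transforms)
        total = float(sum(p))
        # Normalize so that the probabilities sum to 1.
        self.p = [w / total for w in p]
        self.rng = rng if rng is not None else random.Random()

    def __call__(self, data: Any) -> Any:
        # Pick exactly one transform according to self.p and apply it.
        transform = self.rng.choices(self.transforms, weights=self.p, k=1)[0]
        return transform(data)


choice = RandomChoiceSketch([str.upper, str.lower], p=[2, 6], rng=random.Random(0))
print(choice.p)  # [0.25, 0.75]
print(choice("MiXeD"))
```

Here str.lower is picked three times as often as str.upper, because the weights 2 and 6 are rescaled to 0.25 and 0.75.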

class ConditionalChoice(condition, true_transform, false_transform)

Bases: object

Conditionally apply a single transformation based on whether a condition is met.

Parameters:
  • condition (Callable) – callable that takes a data object and returns a boolean

  • true_transform (Callable) – transformation to apply if the condition is met

  • false_transform (Callable) – transformation to apply if the condition is not met
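A minimal sketch of the dispatch logic, using a hypothetical `ConditionalChoiceSketch` class and a dict in place of the data object. The condition shown (a threshold on a hypothetical `num_units` key) is only illustrative.

```python
from typing import Any, Callable


class ConditionalChoiceSketch:
    """Hypothetical stand-in illustrating the documented ConditionalChoice behavior."""

    def __init__(
        self,
        condition: Callable[[Any], bool],
        true_transform: Callable[[Any], Any],
        false_transform: Callable[[Any], Any],
    ) -> None:
        self.condition = condition
        self.true_transform = true_transform
        self.false_transform = false_transform

    def __call__(self, data: Any) -> Any:
        # Exactly one of the two transforms runs, chosen by the condition.
        if self.condition(data):
            return self.true_transform(data)
        return self.false_transform(data)


def identity(d: dict) -> dict:
    return d


def mark_large(d: dict) -> dict:
    return {**d, "large": True}


transform = ConditionalChoiceSketch(
    condition=lambda d: d["num_units"] > 100,
    true_transform=mark_large,
    false_transform=identity,
)
print(transform({"num_units": 250}))  # {'num_units': 250, 'large': True}
print(transform({"num_units": 50}))   # {'num_units': 50}
```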

class UnitDropout(field='spikes', reset_index=True, *args, **kwargs)

Bases: object

Augmentation that randomly drops units from the sample. By default, the number of units to keep is sampled from a triangular distribution defined in TriangleDistribution.

This transform assumes that the data has a units object. It works for both IrregularTimeSeries and RegularTimeSeries fields: for an IrregularTimeSeries, it drops the spikes that belong to units that are not kept; for a RegularTimeSeries, it drops the columns corresponding to those units.

Parameters:
  • field (str, optional) – Field to apply the dropout. Defaults to “spikes”.

  • *args – Arguments to pass to the TriangleDistribution constructor.

  • **kwargs – Arguments to pass to the TriangleDistribution constructor.
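The two dropout cases described above can be sketched with NumPy. This is not the library's code: the array layout (a per-spike unit index for the irregular case, one column per unit for the regular case), the hypothetical function name `unit_dropout_sketch`, and passing `num_keep` directly instead of drawing it from TriangleDistribution are all assumptions. The sketch also does not re-index the surviving units.

```python
import numpy as np


def unit_dropout_sketch(num_units, num_keep, spike_times, spike_unit_index,
                        binned, rng):
    """Hypothetical illustration of dropping units from spikes and binned data.

    spike_times, spike_unit_index: 1-D arrays with one entry per spike.
    binned: 2-D array of shape (num_timesteps, num_units).
    """
    # Choose which units survive, without replacement.
    keep = np.sort(rng.choice(num_units, size=min(num_keep, num_units),
                              replace=False))
    # Irregular case: keep only the spikes emitted by surviving units.
    mask = np.isin(spike_unit_index, keep)
    kept_spikes = (spike_times[mask], spike_unit_index[mask])
    # Regular case: keep only the columns of surviving units.
    kept_binned = binned[:, keep]
    return keep, kept_spikes, kept_binned


rng = np.random.default_rng(0)
spike_times = np.array([0.01, 0.02, 0.05, 0.07, 0.10, 0.12])
spike_unit_index = np.array([0, 3, 1, 4, 0, 2])
binned = np.arange(20).reshape(4, 5)  # 4 time bins x 5 units

keep, (times, units), kept_binned = unit_dropout_sketch(
    5, 3, spike_times, spike_unit_index, binned, rng)
print(keep, units, kept_binned.shape)
```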

class TriangleDistribution(min_units=20, mode_units=100, max_units=300, tail_right=None, peak=4, M=10, max_attempts=100, seed=None)

Bases: object

Triangular distribution over the number of units, ranging from min_units to max_units with its peak at mode_units.

The unnormalized density function is defined as:

\[
f(x) = \begin{cases}
0 & \text{if } x < \text{min\_units} \\
1 + (\text{peak} - 1) \cdot \frac{x - \text{min\_units}}{\text{mode\_units} - \text{min\_units}} & \text{if } \text{min\_units} \leq x \leq \text{mode\_units} \\
\text{peak} - (\text{peak} - 1) \cdot \frac{x - \text{mode\_units}}{\text{tail\_right} - \text{mode\_units}} & \text{if } \text{mode\_units} \leq x \leq \text{tail\_right} \\
1 & \text{if } \text{tail\_right} \leq x \leq \text{max\_units} \\
0 & \text{otherwise}
\end{cases}
\]

Parameters:
  • min_units (int) – Minimum number of units to sample. If the population has fewer units than this, all units will be kept.

  • mode_units (int) – Mode of the distribution.

  • max_units (int) – Maximum number of units to sample.

  • tail_right (int, optional) – Right end of the decreasing segment of the density. Between tail_right and max_units, the density stays flat at 1. If None, it is set to max_units.

  • peak (float, optional) – Height of the unnormalized density at mode_units, relative to its value of 1 at min_units.

  • M (float, optional) – Constant M in the rejection-sampling acceptance probability f(x) / (M · q(x)) described below; M · q(x) should be at least f(x) for every x.

  • max_attempts (int, optional) – Maximum number of attempts to sample from the distribution.

  • seed (int, optional) – Seed for the random number generator.
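As a worked example, take the default parameters min_units = 20, mode_units = 100, max_units = 300 and peak = 4. Since tail_right is None, it equals max_units = 300, so the flat segment is empty. Evaluating the unnormalized density at a few points:

\[
\begin{aligned}
f(10) &= 0 && (x < \text{min\_units}) \\
f(60) &= 1 + (4 - 1) \cdot \frac{60 - 20}{100 - 20} = 1 + 3 \cdot 0.5 = 2.5 \\
f(100) &= 1 + (4 - 1) \cdot \frac{100 - 20}{100 - 20} = 4 \\
f(200) &= 4 - (4 - 1) \cdot \frac{200 - 100}{300 - 100} = 4 - 3 \cdot 0.5 = 2.5 \\
f(300) &= 4 - (4 - 1) \cdot \frac{300 - 100}{300 - 100} = 1
\end{aligned}
\]

A count at mode_units is therefore peak = 4 times as likely as a count at min_units. Choosing tail_right < max_units would keep the density at 1 on the interval from tail_right to max_units instead of letting it fall to 1 only at max_units.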

[Figure: the triangle distribution]

To sample from the distribution, we use rejection sampling. We sample from a uniform distribution between min_units and max_units and accept the sample with probability \(\frac{f(x)}{M \cdot q(x)}\), where \(q(x)\) is the proposal distribution.
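A self-contained sketch of this rejection sampler follows. It illustrates the technique under stated assumptions and is not the library's sample method: the proposal q(x) is taken to be the constant 1 on the integers in [min_units, max_units] (an unnormalized uniform), so the acceptance probability f(x) / (M · q(x)) is valid as long as M ≥ peak. Raising RuntimeError after max_attempts and capping the result at the population size are also assumptions; only the "keep all units below min_units" rule comes from the parameter description above.

```python
import random
from typing import Optional


def unnormalized_density(x: float, min_units: int = 20, mode_units: int = 100,
                         max_units: int = 300, tail_right: Optional[int] = None,
                         peak: float = 4.0) -> float:
    """Piecewise density f(x) from the formula above (not normalized)."""
    if tail_right is None:
        tail_right = max_units
    if x < min_units or x > max_units:
        return 0.0
    if x <= mode_units:
        # Rising edge: 1 at min_units, peak at mode_units.
        return 1 + (peak - 1) * (x - min_units) / (mode_units - min_units)
    if x <= tail_right:
        # Falling edge: peak at mode_units, 1 at tail_right.
        return peak - (peak - 1) * (x - mode_units) / (tail_right - mode_units)
    # Flat segment between tail_right and max_units.
    return 1.0


def sample_num_units(num_units: int, min_units: int = 20, mode_units: int = 100,
                     max_units: int = 300, tail_right: Optional[int] = None,
                     peak: float = 4.0, M: float = 10.0, max_attempts: int = 100,
                     rng: Optional[random.Random] = None) -> int:
    """Hypothetical rejection sampler for the number of units to keep."""
    rng = rng if rng is not None else random.Random()
    if num_units < min_units:
        # Documented behavior: small populations keep all their units.
        return num_units
    for _ in range(max_attempts):
        x = rng.randint(min_units, max_units)  # draw from the uniform proposal
        q = 1.0  # assumed unnormalized uniform proposal density
        f = unnormalized_density(x, min_units, mode_units, max_units,
                                 tail_right, peak)
        if rng.random() < f / (M * q):  # accept with probability f / (M q)
            # Assumption: never keep more units than the population has.
            return min(x, num_units)
    raise RuntimeError("no sample accepted within max_attempts")


rng = random.Random(0)
draws = [sample_num_units(1000, rng=rng) for _ in range(2000)]
print(min(draws), max(draws))
```

With the defaults, roughly one candidate in four is accepted, so max_attempts = 100 is rarely exhausted; a larger M lowers the acceptance rate without changing the sampled distribution.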

unnormalized_density_function(x)

proposal_distribution(x)

sample(num_units)

class RandomTimeScaling(min_scale, max_scale, min_offset=0, max_offset=0)

Bases: object

Randomly scales the time axis.

class RandomOutputSampler(num_output_tokens)

Bases: object

Randomly drops output samples.