Source code for torch_brain.transforms.unit_dropout

import logging
from typing import Optional

import numpy as np

from temporaldata import Data, IrregularTimeSeries, RegularTimeSeries


class TriangleDistribution:
    r"""Triangular distribution with a peak at `mode_units`, going from
    `min_units` to `max_units`.

    The unnormalized density function is defined as:

    .. math::
        f(x) = \begin{cases}
            0 & \text{if } x < \text{min_units} \\
            1 + (\text{peak} - 1) \cdot \frac{x - \text{min_units}}{\text{mode_units} - \text{min_units}}
                & \text{if } \text{min_units} \leq x \leq \text{mode_units} \\
            \text{peak} - (\text{peak} - 1) \cdot \frac{x - \text{mode_units}}{\text{tail_right} - \text{mode_units}}
                & \text{if } \text{mode_units} \leq x \leq \text{tail_right} \\
            1 & \text{if } \text{tail_right} \leq x \leq \text{max_units} \\
            0 & \text{otherwise}
        \end{cases}

    Args:
        min_units (int): Minimum number of units to sample. If the population
            has fewer units than this, all units will be kept.
        mode_units (int): Mode of the distribution.
        max_units (int): Maximum number of units to sample.
        tail_right (int, optional): Right tail of the distribution. If None,
            it is set to `max_units`.
        peak (float, optional): Height of the peak of the distribution.
        M (float, optional): Normalization constant for the proposal
            distribution. It should be at least `peak` so that the acceptance
            ratio never exceeds 1.
        max_attempts (int, optional): Maximum number of attempts to sample
            from the distribution.
        seed (int, optional): Seed for the random number generator.

    .. image:: /_static/img/triangle_distribution.png

    To sample from the distribution, we use rejection sampling. We sample from
    a uniform distribution between `min_units` and `max_units` and accept the
    sample with probability :math:`\frac{f(x)}{M \cdot q(x)}`, where
    :math:`q(x)` is the (unnormalized, constant) proposal distribution.
    """

    def __init__(
        self,
        min_units: int = 20,
        mode_units: int = 100,
        max_units: int = 300,
        tail_right: Optional[int] = None,
        peak: float = 4,
        M: float = 10,
        max_attempts: int = 100,
        seed: Optional[int] = None,
    ):
        super().__init__()
        self.min_units = min_units
        self.mode_units = mode_units
        self.max_units = max_units
        self.tail_right = tail_right if tail_right is not None else max_units
        self.peak = peak
        self.M = M
        self.max_attempts = max_attempts

        # TODO pass a generator?
        self.rng = np.random.default_rng(seed=seed)

    def unnormalized_density_function(self, x: float) -> float:
        """Evaluate the unnormalized triangular density ``f(x)``.

        Returns 0 outside ``[min_units, max_units]``, rises linearly from 1 to
        ``peak`` on ``[min_units, mode_units]``, falls linearly back to 1 on
        ``[mode_units, tail_right]`` and stays at 1 up to ``max_units``.
        """
        if x < self.min_units or x > self.max_units:
            return 0

        if x <= self.mode_units:
            rise = self.mode_units - self.min_units
            # A zero-width rising edge means x sits exactly on the peak;
            # returning it directly avoids a 0/0 division.
            if rise == 0:
                return self.peak
            return 1 + (self.peak - 1) * (x - self.min_units) / rise

        if x <= self.tail_right:
            # Here mode_units < x <= tail_right, so the denominator is > 0.
            return self.peak - (self.peak - 1) * (x - self.mode_units) / (
                self.tail_right - self.mode_units
            )

        return 1

    def proposal_distribution(self, x: float) -> float:
        """Return the unnormalized density ``q(x)`` of the uniform proposal.

        The proposal is uniform over the sampled range, so its unnormalized
        density is the constant 1 regardless of ``x``. It must be
        deterministic: a random value here would make the acceptance ratio
        in :meth:`sample` random and distort the target distribution.
        """
        return 1.0

    def sample(self, num_units: int):
        """Draw a number of units to keep using rejection sampling.

        Args:
            num_units: Number of units available in the population.

        Returns:
            ``num_units`` unchanged if it is below ``min_units`` or if no
            candidate was accepted within ``max_attempts`` attempts (a warning
            is logged in the latter case). Otherwise the accepted candidate,
            a float in ``[min_units, max_units)`` that is neither rounded nor
            clamped to ``num_units``.
        """
        if num_units < self.min_units:
            return num_units

        span = self.max_units - self.min_units

        # uses rejection sampling, bounded to exactly max_attempts candidates
        for _ in range(self.max_attempts):
            # Sample a candidate from the proposal distribution
            x = self.min_units + self.rng.uniform() * span

            u = self.rng.uniform()
            acceptance = self.unnormalized_density_function(x) / (
                self.M * self.proposal_distribution(x)
            )
            if u <= acceptance:
                return x

        logging.warning(
            "Could not sample from distribution after %d attempts,"
            " using all units.",
            self.max_attempts,
        )
        return num_units
class UnitDropout:
    r"""Augmentation that randomly drops units from the sample.

    By default, the number of units to keep is sampled from a triangular
    distribution defined in :class:`TriangleDistribution`.

    This transform assumes that the data has a `units` object. It works for
    both :class:`IrregularTimeSeries` and :class:`RegularTimeSeries`. For the
    former, it will drop spikes from the units that are not kept. For the
    latter, it will drop the corresponding columns from the data.

    Args:
        field (str, optional): Field to apply the dropout. Defaults to
            "spikes". For a :class:`RegularTimeSeries` it must have the form
            ``"<object>.<attribute>"``.
        reset_index (bool, optional): If True, `data.units` is restricted to
            the kept units and, for an :class:`IrregularTimeSeries`, its
            `unit_index` is relabeled to index into the reduced set of units.
            Defaults to True.
        \*args, \*\*kwargs: Arguments to pass to the
            :class:`TriangleDistribution` constructor.
    """

    def __init__(self, field: str = "spikes", reset_index=True, *args, **kwargs):
        # TODO allow multiple fields (example: spikes + LFP)
        self.field = field
        self.reset_index = reset_index

        # TODO this currently assumes the type of distribution we use, in the
        # future, the distribution might be passed as an argument.
        self.distribution = TriangleDistribution(*args, **kwargs)

    def __call__(self, data: Data):
        """Drop a random subset of units from ``data`` and return it.

        ``data`` is modified in place: the target field (and, when
        ``reset_index`` is True, ``data.units``) is reassigned.

        Raises:
            ValueError: If the target object is neither an
                :class:`IrregularTimeSeries` nor a :class:`RegularTimeSeries`,
                or if it is a :class:`RegularTimeSeries` and ``field`` is not
                of the form ``"<object>.<attribute>"``.
        """
        # get units from data
        unit_ids = data.units.id
        num_units = len(unit_ids)

        # sample the number of units to keep from the population
        num_units_to_sample = int(self.distribution.sample(num_units))

        # shuffle units and take the first num_units_to_sample. The shuffle
        # uses the distribution's generator (not the global numpy state) so
        # that the `seed` argument controls which units are kept as well as
        # how many.
        shuffled = self.distribution.rng.permutation(num_units)
        keep_indices = shuffled[:num_units_to_sample]

        unit_mask = np.zeros_like(unit_ids, dtype=bool)
        unit_mask[keep_indices] = True

        if self.reset_index:
            data.units = data.units.select_by_mask(unit_mask)

        nested_attr = self.field.split(".")
        target_obj = getattr(data, nested_attr[0])

        if isinstance(target_obj, IrregularTimeSeries):
            # make a mask to select spikes that are from the units we want to
            # keep
            spike_mask = np.isin(target_obj.unit_index, keep_indices)

            # using lazy masking, we will apply the mask for all attributes
            # from spikes and units.
            setattr(data, self.field, target_obj.select_by_mask(spike_mask))

            if self.reset_index:
                # Map each old unit index to its position among kept units.
                # Entries for dropped units stay 0 but are never looked up,
                # since their spikes were removed above.
                relabel_map = np.zeros(num_units, dtype=int)
                relabel_map[unit_mask] = np.arange(unit_mask.sum())

                target_obj = getattr(data, self.field)
                target_obj.unit_index = relabel_map[target_obj.unit_index]

        elif isinstance(target_obj, RegularTimeSeries):
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(nested_attr) != 2:
                raise ValueError(
                    "Expected a field of the form '<object>.<attribute>' for a "
                    f"RegularTimeSeries, got {self.field!r}"
                )
            attr_name = nested_attr[1]
            # keep only the columns that belong to the kept units
            setattr(
                target_obj,
                attr_name,
                getattr(target_obj, attr_name)[:, unit_mask],
            )

        else:
            raise ValueError(f"Unsupported type for {self.field}: {type(target_obj)}")

        return data