Source code for torch_brain.utils.binning
from typing import Optional
import numpy as np
from temporaldata import IrregularTimeSeries
[docs]
def bin_spikes(
    spikes: IrregularTimeSeries,
    num_units: int,
    bin_size: float,
    max_spikes: Optional[int] = None,
    right: bool = True,
    eps: float = 1e-3,
    dtype: np.dtype = np.int32,
) -> np.ndarray:
    r"""Bin spikes into time bins of size ``bin_size``.

    The binned window spans from the first domain start to the last domain end
    of ``spikes``. If that duration is not a multiple of ``bin_size``, the
    window is shortened to the nearest lower multiple of ``bin_size`` and the
    spikes are resliced to it. If ``right`` is True, the excess time is removed
    from the left end of the window, otherwise from the right end.

    Notes:
        - The number of units cannot be inferred from a subset of spikes,
          so ``num_units`` must be provided explicitly.
        - Floating-point roundoff can cause ``(end - start) / bin_size`` to be
          very close to an integer without being exact (e.g. 9.99999999).
          ``eps`` is added before flooring to make the bin-count computation
          numerically robust.
        - Roundoff in the bin assignment can also map a spike lying just
          inside the end of the window to bin index ``T``; such a spike is
          counted in the last bin (``T - 1``) instead of raising an error.

    Args:
        spikes: IrregularTimeSeries object containing the spikes. It is read
            through ``domain.start``, ``domain.end``, ``timestamps``,
            ``unit_index`` and ``slice(start, end)``.
        num_units: Number of units in the population (number of output
            columns).
        bin_size: Size of the time bins in seconds. Must be strictly positive.
        max_spikes: Maximum number of spikes to count per unit per bin. If
            ``None``, no clipping is applied.
        right: Which side gets truncated when the duration is not a multiple
            of ``bin_size``. If ``True``, the excess is removed from the left
            edge; otherwise from the right edge.
        eps: Small margin, in units of bins, added to
            ``(end - start) / bin_size`` before flooring when computing how
            much time to discard.
        dtype: Data type of the output binned array. (default ``np.int32``)

    Returns:
        Binned spike counts with shape ``(T, N)``, where ``T`` is the number of
        time bins and ``N`` is ``num_units``.

    Raises:
        ValueError: If ``bin_size`` is not strictly positive.
    """
    # `not > 0` also rejects NaN, which would otherwise propagate silently.
    if not bin_size > 0:
        raise ValueError(f"bin_size must be strictly positive, got {bin_size}")

    start = spikes.domain.start[0]
    end = spikes.domain.end[-1]

    # Compute how much time must be discarded so that the duration is an
    # exact multiple of `bin_size`. The epsilon stabilizes the floor operation
    # under floating-point roundoff.
    discard = (end - start) - np.floor(((end - start) / bin_size) + eps) * bin_size

    # In exact arithmetic `discard` is non-negative, but `eps` and roundoff can
    # make it slightly negative. In that case we keep the full window rather
    # than reslicing, to avoid dropping the last spike.
    if discard > 0:
        if right:
            start += discard
        else:
            end -= discard
        spikes = spikes.slice(start, end)

    num_bins = round((end - start) / bin_size)
    binned_spikes = np.zeros((num_bins, num_units), dtype=dtype)
    if num_bins == 0:
        # An empty window has no bins to fill; indexing below would fail.
        return binned_spikes

    rate = 1 / bin_size
    # Express timestamps relative to the (possibly resliced) domain start so
    # that a non-zero origin is handled.
    ts = spikes.timestamps - spikes.domain.start[0]
    bin_index = np.floor(ts * rate).astype(int)
    # Roundoff in `ts * rate` can push a spike that lies just inside the end of
    # the window onto index `num_bins`, one past the last bin. Fold that
    # boundary spill back into the last bin. Other out-of-range indices are
    # left untouched so that malformed input is not silently hidden.
    bin_index[bin_index == num_bins] = num_bins - 1
    np.add.at(binned_spikes, (bin_index, spikes.unit_index), 1)

    if max_spikes is not None:
        np.clip(binned_spikes, None, max_spikes, out=binned_spikes)
    return binned_spikes