Source code for torch_brain.transforms.bin_spikes
from typing import Optional
import numpy as np
from temporaldata import Data, RegularTimeSeries
from torch_brain.utils.binning import bin_spikes
[docs]
class BinSpikes:
    r"""Bin spike events into fixed-width time bins.

    The transform reads spikes and units from nested attributes, applies
    :func:`torch_brain.utils.binning.bin_spikes`, and stores the result in a new
    nested attribute named ``{spikes_attribute}_binned``. The stored object is a
    ``RegularTimeSeries`` with a single field, ``binned_counts``, laid out as
    ``(bins, units)`` so that time runs along axis 0.

    Args:
        bin_size: Bin width in seconds.
        spikes_attribute: Nested attribute path to the spikes ``IrregularTimeSeries``.
        units_attribute: Nested attribute path to the units ``ArrayDict``.
        max_spikes: Maximum number of spikes to include per unit per
            bin. If ``None``, no clipping is applied.
        right: Decide which side gets truncated when duration is not
            a multiple of ``bin_size``. If ``True``, excess spikes are truncated
            from the left edge.
        eps: Small numerical margin used during bin assignment.
        dtype: Data type of the output binned array. (default ``np.int32``)
    """

    def __init__(
        self,
        bin_size: float,
        spikes_attribute: str = "spikes",
        units_attribute: str = "units",
        max_spikes: Optional[int] = None,
        right: bool = True,
        eps: float = 1e-3,
        dtype: np.dtype = np.int32,
    ):
        self.spikes_attr = spikes_attribute
        self.units_attr = units_attribute
        # Keyword arguments forwarded verbatim to ``bin_spikes`` on every call.
        self.params = {
            "bin_size": bin_size,
            "max_spikes": max_spikes,
            "right": right,
            "eps": eps,
            "dtype": dtype,
        }

    def __call__(self, data: Data):
        r"""Bin the spikes found in ``data`` and attach the result to it.

        Args:
            data: Object holding the spikes and units at the configured
                attribute paths.

        Returns:
            The same ``data`` object, with ``{spikes_attribute}_binned`` set.
        """
        spikes = data.get_nested_attribute(self.spikes_attr)
        units = data.get_nested_attribute(self.units_attr)

        binned_counts = bin_spikes(spikes, num_units=len(units), **self.params)

        # RegularTimeSeries expects time on axis 0; bin_spikes returns
        # (units, bins), so transpose to (bins, units) before storing.
        # NOTE(review): when ``right=True`` the excess is trimmed from the left
        # edge, so the first bin may start after ``spikes.domain.start[0]`` --
        # confirm whether ``domain_start`` should carry the same offset.
        binned_spikes = RegularTimeSeries(
            sampling_rate=1 / self.params["bin_size"],
            binned_counts=binned_counts.T,
            domain="auto",
            domain_start=spikes.domain.start[0],
        )

        # TODO switch to Data.set_nested_attribute() when released through temporaldata
        _set_nested_attribute(data, f"{self.spikes_attr}_binned", binned_spikes)
        return data
# TODO remove when Data.set_nested_attribute() is released through temporaldata
def _set_nested_attribute(data, path: str, value):
    r"""Set a nested attribute in a :class:`temporaldata.Data` object using a dot-separated path.

    Every component except the last is resolved with ``getattr`` and must
    already exist; the last component is assigned with ``setattr`` on the
    object reached that way.

    Args:
        data: The :class:`temporaldata.Data` object to modify.
        path: The dot-separated path to the nested attribute (e.g., "session.id").
        value: The value to set for the attribute.

    Returns:
        The modified data object (same instance, modified in-place).

    Raises:
        AttributeError: If any intermediate component of the path cannot be
            resolved. The underlying lookup error is attached as ``__cause__``.
    """
    # ``str.split`` always yields at least one item, so ``leaf`` is always bound.
    *parents, leaf = path.split(".")

    target = data
    for name in parents:
        try:
            target = getattr(target, name)
        except AttributeError as err:
            # Chain explicitly so the original lookup failure stays visible.
            raise AttributeError(
                f"Could not resolve {path} in data (specifically, at level {name})"
            ) from err

    setattr(target, leaf, value)
    return data