resolve_weights_based_on_interval_membership
- torch_brain.utils.resolve_weights_based_on_interval_membership(timestamps, data, config=None)
Determine weights for timestamps based on which intervals they fall within. The intervals and corresponding weights are specified in the config dictionary.
The config dictionary maps interval names to weight values. Nested intervals can be referenced with dotted notation (for example, movement_periods.reach_period). For example:
{
    'movement_periods.random_period': 1.0,
    'movement_periods.hold_period': 0.1,
    'movement_periods.reach_period': 5.0,
    'movement_periods.return_period': 1.0,
    'cursor_outlier_segments': 0.0,
}
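These weights can be used to weight different time periods differently in the loss function. In the example above, reach periods are weighted 5x more heavily than random periods, hold periods count for one tenth as much, and cursor outlier segments are excluded entirely (weight 0.0).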
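The dotted keys in the config can be read as attribute paths into the data object. The sketch below assumes each path segment is a plain attribute access. The helper `get_nested_attribute` and the `SimpleNamespace` stand-in for the data object are hypothetical illustrations, not part of torch_brain.

```python
from functools import reduce
from types import SimpleNamespace


def get_nested_attribute(obj, dotted_name):
    """Follow a dotted path such as 'movement_periods.reach_period'.

    Hypothetical helper: assumes every segment is a plain attribute access.
    """
    return reduce(getattr, dotted_name.split("."), obj)


# Hypothetical stand-in for a data object with a nested interval container.
data = SimpleNamespace(
    movement_periods=SimpleNamespace(reach_period="reach intervals"),
    cursor_outlier_segments="outlier intervals",
)

print(get_nested_attribute(data, "movement_periods.reach_period"))  # reach intervals
print(get_nested_attribute(data, "cursor_outlier_segments"))  # outlier intervals
```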
Note
If intervals overlap, the final weight is the product of the weights of all intervals that contain the timestamp. For example, if a timestamp falls within both a reach_period (weight 5.0) and a cursor_outlier_segments interval (weight 0.0), its final weight is 5.0 * 0.0 = 0.0. This multiplicative behavior lets you combine intervals into more complex weighting schemes; in particular, a weight of 0.0 masks out a timestamp regardless of any other intervals it belongs to.
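A minimal NumPy sketch of the multiplicative rule follows. It assumes each interval set is given as parallel `start`/`end` arrays and that membership means `start <= t < end`; the half-open convention is an assumption, and the real function may treat endpoints differently. `Intervals` and `weights_from_intervals` are hypothetical names, and intervals are passed as a flat dict rather than a data object.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Intervals:
    # Hypothetical container: parallel arrays of interval start/end times.
    start: np.ndarray
    end: np.ndarray


def weights_from_intervals(timestamps, intervals_by_name, config):
    timestamps = np.asarray(timestamps, dtype=float)
    # Every timestamp starts at the default weight of 1.0.
    weights = np.ones_like(timestamps)
    for name, weight in config.items():
        iv = intervals_by_name[name]
        # Boolean matrix (n_timestamps, n_intervals); assumes half-open [start, end).
        inside = (timestamps[:, None] >= iv.start[None, :]) & (
            timestamps[:, None] < iv.end[None, :]
        )
        # Assumption: the weight is applied once per interval set, even if a
        # timestamp lies in several intervals of the same set.
        weights[inside.any(axis=1)] *= weight
    return weights


intervals = {
    "reach_period": Intervals(np.array([0.0]), np.array([2.0])),
    "cursor_outlier_segments": Intervals(np.array([1.0]), np.array([1.5])),
}
config = {"reach_period": 5.0, "cursor_outlier_segments": 0.0}

print(weights_from_intervals([0.5, 1.2, 3.0], intervals, config))
# [5. 0. 1.]
```

The timestamp at 1.2 lies in both the reach period and the outlier segment, so its weight is 5.0 * 0.0 = 0.0; the timestamp at 3.0 lies in neither and keeps 1.0.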
Note
If a timestamp does not belong to any of the intervals in the config, its weight will remain at the default value of 1.0.
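The default-weight rule amounts to starting from an array of ones and only modifying entries that fall inside a configured interval. A small standalone sketch, assuming a single hypothetical reach period covering [1.0, 2.0) with weight 5.0:

```python
import numpy as np

timestamps = np.array([0.5, 1.5, 2.5])
# Hypothetical single interval set: one reach period covering [1.0, 2.0).
reach_start, reach_end = np.array([1.0]), np.array([2.0])

weights = np.ones_like(timestamps)  # default weight of 1.0 for every timestamp
inside = (
    (timestamps[:, None] >= reach_start) & (timestamps[:, None] < reach_end)
).any(axis=1)
weights[inside] *= 5.0  # only timestamps inside the reach period change

print(weights)  # [1. 5. 1.]
```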
- Parameters:
timestamps – Array of timestamps
data – Data object containing the intervals named in config
config – Dictionary mapping interval names to weight values
- Returns:
Array of weights with the same shape as timestamps
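One way the returned weights might be used in a loss function, as the description suggests, is a weighted mean squared error. This is a sketch only: `weighted_mse` is a hypothetical helper, and normalizing by the sum of weights rather than the sample count is an assumed design choice.

```python
import numpy as np


def weighted_mse(pred, target, weights):
    """Mean squared error where each timestamp's error is scaled by its weight.

    Hypothetical helper. Normalizing by the sum of weights (an assumption)
    keeps the loss scale stable when some samples are masked with weight 0.0.
    Raises ZeroDivisionError-like NaN if every weight is zero.
    """
    pred, target, weights = (np.asarray(x, dtype=float) for x in (pred, target, weights))
    per_sample = (pred - target) ** 2
    # Broadcast per-timestamp weights over trailing output dims (e.g. x/y velocity).
    while weights.ndim < per_sample.ndim:
        weights = weights[..., None]
    weights = np.broadcast_to(weights, per_sample.shape)
    return float((weights * per_sample).sum() / weights.sum())


weights = np.array([5.0, 0.0, 1.0])  # e.g. reach, outlier, random period
print(weighted_mse([1.0, 2.0, 3.0], [0.0, 0.0, 1.0], weights))  # 1.5
```

The second sample has the largest error, but its weight of 0.0 removes it from the loss, leaving (5 * 1 + 1 * 4) / 6 = 1.5.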