resolve_weights_based_on_interval_membership

torch_brain.utils.resolve_weights_based_on_interval_membership(timestamps, data, config=None)

Determine weights for timestamps based on which intervals they fall within. The intervals and corresponding weights are specified in the config dictionary.

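The config dictionary maps interval names to weight values. Nested intervals are referenced with dot notation; for instance, 'movement_periods.reach_period' refers to the reach_period interval nested under movement_periods. For example: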

{
    'movement_periods.random_period': 1.0,
    'movement_periods.hold_period': 0.1,
    'movement_periods.reach_period': 5.0,
    'movement_periods.return_period': 1.0,
    'cursor_outlier_segments': 0.0,
}
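The dot notation in the config keys can be read as a path of nested lookups on the data object. The sketch below illustrates that reading under stated assumptions: it uses `SimpleNamespace` objects as stand-ins for the real data object, placeholder strings instead of interval objects, and a hypothetical helper `get_nested` that is not part of torch_brain.

```python
from functools import reduce
from types import SimpleNamespace


def get_nested(obj, dotted_name):
    """Follow a dot-separated attribute path such as 'movement_periods.reach_period'.

    Hypothetical helper for illustration; torch_brain may resolve names differently.
    """
    # reduce applies getattr left to right: getattr(getattr(obj, 'a'), 'b'), ...
    return reduce(getattr, dotted_name.split("."), obj)


# Stand-in for the data object (assumption): strings replace real interval objects.
data = SimpleNamespace(
    movement_periods=SimpleNamespace(
        random_period="random intervals",
        hold_period="hold intervals",
        reach_period="reach intervals",
        return_period="return intervals",
    ),
    cursor_outlier_segments="outlier intervals",
)

config = {
    "movement_periods.random_period": 1.0,
    "movement_periods.hold_period": 0.1,
    "movement_periods.reach_period": 5.0,
    "movement_periods.return_period": 1.0,
    "cursor_outlier_segments": 0.0,
}

for name, weight in config.items():
    print(f"{name} -> {get_nested(data, name)} (weight {weight})")
```

A key without dots, such as 'cursor_outlier_segments', is simply a single attribute lookup.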

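These weights can be used to scale how much each time period contributes to the loss function. With the example config, reach periods are weighted 5x more heavily than random periods, hold periods count one tenth as much as random periods, and timestamps inside cursor outlier segments are excluded entirely (weight 0.0).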
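A common way to apply such weights is to scale each timestamp's error before reducing it to a scalar loss. The sketch below uses NumPy rather than PyTorch so that it runs standalone; the function name `weighted_mse` and the choice to normalize by the sum of the weights are assumptions for illustration, not torch_brain's loss implementation.

```python
import numpy as np


def weighted_mse(pred, target, weights):
    """Squared error per timestamp, scaled by its weight (illustrative sketch).

    Dividing by the weight sum (an assumption; a plain mean is another option)
    keeps the loss scale comparable across batches with different weights.
    """
    per_sample = (np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)) ** 2
    weights = np.asarray(weights, dtype=float)
    total = weights.sum()
    if total == 0:
        # Every timestamp is masked out, so nothing contributes to the loss.
        return 0.0
    return float((weights * per_sample).sum() / total)


# Three timestamps: a reach sample (5.0), a random-period sample (1.0),
# and a sample inside an outlier segment (0.0) with a large error.
loss = weighted_mse(pred=[1.0, 3.0, 9.0], target=[0.0, 1.0, 0.0], weights=[5.0, 1.0, 0.0])
print(loss)  # → 1.5
```

The squared error of 81.0 on the outlier sample has no effect on the result because its weight is 0.0: the loss is (5.0 * 1.0 + 1.0 * 4.0) / 6.0 = 1.5.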

Note

If intervals overlap, the final weight is the product of the weights of all intervals that contain the timestamp. For example, if a timestamp falls within both a reach_period (weight 5.0) and a cursor_outlier_segments interval (weight 0.0), its final weight is 5.0 * 0.0 = 0.0. This multiplicative rule lets weights from several interval types be combined; in particular, a weight of 0.0 masks out a timestamp regardless of any other interval it belongs to.

Note

If a timestamp does not belong to any of the intervals in the config, its weight will remain at the default value of 1.0.
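The two notes above describe the combination rule: every timestamp starts at 1.0, and the weight of each configured interval type that contains it is multiplied in. Below is a minimal NumPy sketch of that rule, not torch_brain's implementation. It assumes each interval type is given as arrays of starts and ends, treats membership as half-open (start <= t < end), and counts a timestamp only once per interval type even if several intervals of that type contain it; the function name `interval_membership_weights` is hypothetical.

```python
import numpy as np


def interval_membership_weights(timestamps, intervals, config):
    """Product of the weights of all configured interval types containing each timestamp.

    Assumptions (hypothetical, for illustration): `intervals` maps each config key
    to a (starts, ends) pair of arrays, and membership is half-open, start <= t < end.
    """
    timestamps = np.asarray(timestamps, dtype=float)
    # Timestamps outside every configured interval keep the default weight of 1.0.
    weights = np.ones_like(timestamps)
    for name, weight in config.items():
        starts, ends = intervals[name]
        starts = np.asarray(starts, dtype=float)
        ends = np.asarray(ends, dtype=float)
        # Compare every timestamp with every interval of this type: shape (..., num_intervals).
        t = timestamps[..., None]
        inside_any = ((t >= starts) & (t < ends)).any(axis=-1)
        # Overlapping interval types combine multiplicatively, one factor per type.
        weights[inside_any] *= weight
    return weights


intervals = {
    "movement_periods.reach_period": ([0.0, 2.0], [1.0, 3.0]),
    "cursor_outlier_segments": ([0.5], [0.7]),
}
config = {"movement_periods.reach_period": 5.0, "cursor_outlier_segments": 0.0}

weights = interval_membership_weights([0.25, 0.6, 1.5, 2.5], intervals, config)
print(weights.tolist())  # → [5.0, 0.0, 1.0, 5.0]
```

The second timestamp lies inside both a reach period and an outlier segment, so its weight is 5.0 * 0.0 = 0.0, while the third lies in no configured interval and keeps 1.0. Because the comparison broadcasts over a trailing axis, the returned array has the same shape as the input timestamps.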

Parameters:
  • timestamps – Array of timestamps

  • data – Data object containing intervals

  • config – Dictionary mapping interval names to weight values

Returns:

Array of weights with the same shape as timestamps