stitch

torch_brain.utils.stitch(timestamps, values)

Pools values that share the same timestamp, taking the mean of floating-point values and the mode of categorical values.

This function is useful when you have multiple predictions or values for the same timestamp (e.g., from overlapping windows) and need to combine them into a single value per timestamp.
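The mean-pooling case can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not the torch_brain source: the helper name mean_stitch is hypothetical, and it assumes grouping with torch.unique(..., return_inverse=True) followed by a per-group sum and count. The usage lines simulate two overlapping windows whose 2-D predictions share timestamps 2 and 3.

```python
import torch


def mean_stitch(timestamps: torch.Tensor, values: torch.Tensor):
    """Average all rows of `values` that share a timestamp.

    Hypothetical sketch, not the torch_brain implementation.
    timestamps: shape (N,); values: floating point, shape (N, ...).
    """
    # Sorted unique timestamps, plus the group index of every input row.
    unique_ts, inverse = torch.unique(timestamps, sorted=True, return_inverse=True)
    num_groups = unique_ts.shape[0]

    # Sum the rows of each group; index_add_ along dim 0 handles any trailing shape.
    sums = torch.zeros((num_groups,) + tuple(values.shape[1:]),
                       dtype=values.dtype, device=values.device)
    sums.index_add_(0, inverse, values)

    # Rows per group, reshaped so the division broadcasts over trailing dims.
    counts = torch.bincount(inverse, minlength=num_groups).to(values.dtype)
    counts = counts.view((num_groups,) + (1,) * (values.dim() - 1))
    return unique_ts, sums / counts


# Two overlapping windows: the first covers timestamps 0-3, the second 2-5.
ts = torch.cat([torch.tensor([0, 1, 2, 3]), torch.tensor([2, 3, 4, 5])])
preds = torch.cat([torch.zeros(4, 2), torch.ones(4, 2)])
unique_ts, pooled = mean_stitch(ts, preds)
# Rows for timestamps 2 and 3 average one 0.0 prediction with one 1.0 prediction.
```

Because index_add_ and bincount operate on the group index rather than on the timestamps themselves, the sketch works for any integer or float timestamp dtype that torch.unique accepts.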

Parameters:
  • timestamps (torch.Tensor) – A 1D tensor containing timestamps. Shape: (N,).

  • values (torch.Tensor) – Values corresponding to the timestamps. Shape: (N, ...) for floating-point types, or (N,) for categorical types (torch.long only).

Returns:

A tuple (unique_timestamps, pooled_values). unique_timestamps is a 1D tensor of the sorted unique timestamps. pooled_values contains one pooled value per unique timestamp: the mean for continuous (floating-point) data, or the mode for categorical (torch.long) data.
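For categorical data, a per-timestamp majority vote can be sketched by counting (timestamp, label) pairs with torch.bincount. This is again a hypothetical helper (mode_stitch), not the library's code. This page does not say how torch_brain breaks ties, so the sketch's rule (the smallest label wins, because torch.argmax returns the index of the first maximal value) is an assumption.

```python
import torch


def mode_stitch(timestamps: torch.Tensor, labels: torch.Tensor):
    """Majority vote over all labels that share a timestamp.

    Hypothetical sketch, not the torch_brain implementation.
    timestamps: shape (N,); labels: torch.long, shape (N,).
    Assumed tie rule: the smallest label wins.
    """
    unique_ts, ts_idx = torch.unique(timestamps, sorted=True, return_inverse=True)
    unique_labels, label_idx = torch.unique(labels, sorted=True, return_inverse=True)
    num_ts, num_labels = unique_ts.shape[0], unique_labels.shape[0]

    # counts[t, c] = how often label c occurs at timestamp t, obtained by
    # counting flattened (timestamp, label) pair indices.
    pair_idx = ts_idx * num_labels + label_idx
    counts = torch.bincount(pair_idx, minlength=num_ts * num_labels)
    counts = counts.view(num_ts, num_labels)

    # Most frequent label per timestamp, mapped back to the original label values.
    return unique_ts, unique_labels[counts.argmax(dim=1)]


ts = torch.tensor([1, 1, 2, 3, 3, 3])
labels = torch.tensor([1, 1, 2, 3, 3, 1], dtype=torch.long)
unique_ts, modes = mode_stitch(ts, labels)
# unique_ts holds [1, 2, 3] and modes holds [1, 2, 3].
```

The count matrix has one row per unique timestamp and one column per unique label, so its memory grows with their product; that is acceptable for a small label set but worth keeping in mind for many classes.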

Return type:

tuple[torch.Tensor, torch.Tensor]

Examples

>>> import torch
>>> from torch_brain.utils import stitch
>>> # Mean pooling for continuous values
>>> timestamps = torch.tensor([1, 1, 2, 3, 3])
>>> values = torch.tensor([0.1, 0.3, 0.2, 0.4, 0.6])
>>> stitch(timestamps, values)
(tensor([1, 2, 3]), tensor([0.2000, 0.2000, 0.5000]))
>>> # Mode pooling for categorical values
>>> timestamps = torch.tensor([1, 1, 2, 3, 3, 3])
>>> values = torch.tensor([1, 1, 2, 3, 3, 1], dtype=torch.long)
>>> stitch(timestamps, values)
(tensor([1, 2, 3]), tensor([1, 2, 3]))