# Source code for torch_brain.data.concat
from functools import reduce
import numpy as np
from .irregular_ts import IrregularTimeSeries
# [docs]
def concat(objs, sort=True):
    """Concatenate multiple time series objects into a single object.

    Args:
        objs: Sequence (or any iterable) of time series objects to
            concatenate. All objects must be of the same type and share the
            same keys and timekeys.
        sort: Whether to sort the resulting time series by timestamps. Only
            applies to IrregularTimeSeries. Defaults to True.

    Returns:
        IrregularTimeSeries: The concatenated time series object. Each key is
        the ``np.concatenate`` of that key across ``objs`` (in input order,
        then sorted if ``sort`` is True), and its domain is the union (``|``)
        of the input domains.

    Raises:
        ValueError: If ``objs`` is empty, if the objects are not all of the
            same type, or if they don't have matching keys or timekeys.
        NotImplementedError: If concatenation is not implemented for the
            given object type (currently only IrregularTimeSeries is
            supported).

    Example ::

        >>> from torch_brain.data import IrregularTimeSeries, concat
        >>> ts1 = IrregularTimeSeries(
        ...     timestamps=[0.0, 1.0],
        ...     values=[1.0, 2.0],
        ...     domain="auto",
        ... )
        >>> ts2 = IrregularTimeSeries(
        ...     timestamps=[2.0, 3.0],
        ...     values=[3.0, 4.0],
        ...     domain="auto",
        ... )
        >>> ts_concat = concat([ts1, ts2])
        >>> ts_concat
        IrregularTimeSeries(
          timestamps=[4],
          values=[4]
        )
        >>> ts_concat.timestamps
        array([0., 1., 2., 3.])
    """
    # Materialize once so generators/iterators work and can be traversed
    # several times below.
    objs = list(objs)
    if not objs:
        # Without this guard, objs[0] would raise an unhelpful IndexError.
        raise ValueError("Cannot concatenate an empty sequence of objects.")

    # All objects must be instances of the first object's type.
    obj_type = type(objs[0])
    if any(not isinstance(obj, obj_type) for obj in objs):
        raise ValueError(
            f"All objects must be of the same type, got: {[type(obj) for obj in objs]}"
        )

    if obj_type is not IrregularTimeSeries:
        raise NotImplementedError(f"Concatenation not implemented for type: {obj_type}")

    keys = objs[0].keys()
    timekeys = objs[0].timekeys()
    for obj in objs:
        # Compare as sets: attribute order does not matter, only membership.
        if set(obj.keys()) != set(keys):
            raise ValueError(
                f"All objects must have the same keys, got {keys} and {obj.keys()}"
            )
        if set(obj.timekeys()) != set(timekeys):
            raise ValueError(
                f"All objects must have the same timekeys, got {timekeys} and {obj.timekeys()}"
            )

    # The combined domain is the union of every input's domain.
    domain = reduce(lambda x, y: x | y, [obj.domain for obj in objs])

    obj_concat_dict = {
        k: np.concatenate([getattr(obj, k) for obj in objs]) for k in keys
    }
    obj_concat = IrregularTimeSeries(
        **obj_concat_dict, timekeys=timekeys, domain=domain
    )
    if sort:
        # The return value of sort() is not used, so this relies on it
        # reordering obj_concat in place.
        obj_concat.sort()
    return obj_concat