Source code for torch_brain.data.concat

from functools import reduce

import numpy as np

from .irregular_ts import IrregularTimeSeries


def concat(objs, sort=True):
    """Concatenates multiple time series objects into a single object.

    Every attribute listed by ``keys()`` is concatenated with
    ``np.concatenate`` across all objects (in the order given), and the
    domain of the result is the union (``|``) of the input domains.

    Args:
        objs: Non-empty iterable of time series objects to concatenate. All
            objects must be of the same type and expose the same keys and
            timekeys. Currently only IrregularTimeSeries is supported.
        sort: Whether to sort the resulting time series by timestamps. Only
            applies to IrregularTimeSeries. Defaults to True.

    Returns:
        IrregularTimeSeries: The concatenated time series object.

    Raises:
        ValueError: If ``objs`` is empty, if objects are not all of the same
            type, or if they don't have matching keys or timekeys.
        NotImplementedError: If concatenation is not implemented for the given
            object type.

    Example ::

        >>> from torch_brain.data import IrregularTimeSeries, Interval, concat

        >>> ts1 = IrregularTimeSeries(
        ...     timestamps=[0.0, 1.0],
        ...     values=[1.0, 2.0],
        ...     domain="auto",
        ... )
        >>> ts2 = IrregularTimeSeries(
        ...     timestamps=[2.0, 3.0],
        ...     values=[3.0, 4.0],
        ...     domain="auto",
        ... )

        >>> ts_concat = concat([ts1, ts2])
        >>> ts_concat
        IrregularTimeSeries(
          timestamps=[4],
          values=[4]
        )
        >>> ts_concat.timestamps
        array([0., 1., 2., 3.])
    """
    # Materialise the input: it is indexed and iterated several times below,
    # which would fail (or silently exhaust) on a generator.
    objs = list(objs)
    if not objs:
        # Without this guard `objs[0]` would raise an opaque IndexError.
        raise ValueError("Cannot concatenate an empty sequence of objects")

    # check if all objects are of the same type
    obj_type = type(objs[0])
    if any(not isinstance(obj, obj_type) for obj in objs):
        raise ValueError(
            f"All objects must be of the same type, got: {[type(obj) for obj in objs]}"
        )

    # Only exact IrregularTimeSeries instances are handled; bail out early for
    # anything else so the main path stays un-nested.
    if obj_type is not IrregularTimeSeries:
        raise NotImplementedError(f"Concatenation not implemented for type: {obj_type}")

    # The resulting domain is the union of all the individual domains.
    domain = reduce(lambda x, y: x | y, [obj.domain for obj in objs])

    # The first object is the reference for the expected keys / timekeys.
    keys = objs[0].keys()
    timekeys = objs[0].timekeys()
    expected_keys = set(keys)
    expected_timekeys = set(timekeys)
    for obj in objs:
        if set(obj.keys()) != expected_keys:
            raise ValueError(
                f"All objects must have the same keys, got {keys} and {obj.keys()}"
            )
        if set(obj.timekeys()) != expected_timekeys:
            raise ValueError(
                f"All objects must have the same timekeys, got {timekeys} and {obj.timekeys()}"
            )

    # Concatenate each attribute across objects, preserving the input order.
    obj_concat_dict = {
        k: np.concatenate([getattr(obj, k) for obj in objs]) for k in keys
    }

    obj_concat = IrregularTimeSeries(
        **obj_concat_dict, timekeys=timekeys, domain=domain
    )

    if sort:
        obj_concat.sort()

    return obj_concat