class Compose:
    r"""Chain several transforms into a single callable.

    The transforms are invoked one after another, in list order, with the
    output of each step fed as the input of the next. Every transform must
    accept and return a single :obj:`temporaldata.Data` object, except the
    last transform, which can return any object.

    Args:
        transforms (list of callable): list of transforms to compose.
    """

    def __init__(self, transforms: List[Callable]):
        self.transforms = transforms

    def __call__(self, data: temporaldata.Data) -> temporaldata.Data:
        # An empty pipeline is the identity: the input is returned unchanged.
        result = data
        for step in self.transforms:
            result = step(result)
        return result
# similar to torchvision.transforms.v2.RandomChoice
class RandomChoice:
    r"""Apply a single transformation randomly picked from a list.

    Args:
        transforms: list of transformations
        p (list of floats, optional): probability of each transform being
            picked. If :obj:`p` doesn't sum to 1, it is automatically
            normalized. By default, all transforms have the same probability.

    Raises:
        ValueError: if :obj:`transforms` is empty, if :obj:`p` does not have
            exactly one entry per transform, if any entry of :obj:`p` is
            negative, or if the entries of :obj:`p` do not sum to a positive
            value.
    """

    def __init__(self, transforms: List[Callable], p: List[float] = None) -> None:
        # Validate eagerly so misconfiguration surfaces at construction time
        # rather than as an opaque numpy error on the first call.
        if len(transforms) == 0:
            raise ValueError("RandomChoice requires at least one transform.")
        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(
                f"Length of p doesn't match the number of transforms: "
                f"{len(p)} != {len(transforms)}"
            )
        if any(prob < 0 for prob in p):
            raise ValueError(f"Probabilities in p must be non-negative, got {list(p)}.")
        total = sum(p)
        # `not total > 0` also rejects a NaN total, which `total <= 0` would miss;
        # a zero total would otherwise raise ZeroDivisionError below.
        if not total > 0:
            raise ValueError(f"Probabilities in p must sum to a positive value, got {total}.")

        super().__init__()
        self.transforms = transforms
        # Normalize so that np.random.choice receives probabilities summing to 1.
        self.p = [prob / total for prob in p]

    def __call__(self, data: temporaldata.Data) -> temporaldata.Data:
        # Draws from numpy's global RNG, so np.random.seed controls the choice.
        idx = np.random.choice(len(self.transforms), p=self.p)
        transform = self.transforms[idx]
        return transform(data)
# args similar to jax.lax.cond
class ConditionalChoice:
    r"""Conditionally apply a single transformation based on whether a condition is met.

    Args:
        condition: callable that takes a data object and returns a boolean
            (a Python :obj:`bool` or a NumPy :obj:`numpy.bool_`)
        true_transform: transformation to apply if the condition is met
        false_transform: transformation to apply if the condition is not met

    Raises:
        ValueError: if :obj:`condition` returns something other than a
            Python or NumPy boolean.
    """

    def __init__(
        self,
        condition: Callable,
        true_transform: Callable,
        false_transform: Callable,
    ) -> None:
        self.condition = condition
        self.true_transform = true_transform
        self.false_transform = false_transform

    def __call__(self, data: temporaldata.Data) -> temporaldata.Data:
        ret = self.condition(data)
        # NumPy comparisons (e.g. ``arr.mean() > 0``) return ``np.bool_``, which
        # is not a subclass of Python ``bool``; accept both, but still reject
        # truthy non-booleans such as ints or arrays.
        if not isinstance(ret, (bool, np.bool_)):
            raise ValueError(f"Condition must return a boolean, got {type(ret)} instead.")
        if ret:
            return self.true_transform(data)
        return self.false_transform(data)