"""Container transforms that compose or select among other transforms (``torch_brain.transforms.container``)."""
from typing import Any, Callable, List, Optional

import numpy as np
import temporaldata
[docs]
class Compose:
    r"""Compose several transforms together.

    All transforms are called sequentially, in order. Every transform receives the
    output of the previous one. Each transform must accept a single
    :obj:`temporaldata.Data` object. All but the last must also return one; the last
    transform may return any object, which becomes the result of the call. With an
    empty list of transforms, the input is returned unchanged.

    Args:
        transforms (list of callable): list of transforms to compose.
    """

    def __init__(self, transforms: List[Callable]):
        self.transforms = transforms

    def __call__(self, data: temporaldata.Data) -> Any:
        # Annotated ``Any`` rather than ``Data``: the last transform is allowed to
        # return an arbitrary object (see class docstring).
        for transform in self.transforms:
            data = transform(data)
        return data

    def __repr__(self) -> str:
        return f"{type(self).__name__}(transforms={self.transforms!r})"
# Modeled on torchvision.transforms.v2.RandomChoice (same `transforms` / `p` arguments).
[docs]
class RandomChoice:
    r"""Apply a single transformation randomly picked from a list.

    Args:
        transforms: list of transformations. Must not be empty.
        p (list of floats, optional): probability of each transform being picked.
            If :obj:`p` doesn't sum to 1, it is automatically normalized. By default,
            all transforms have the same probability. Entries must be non-negative
            and must not all be zero.

    Raises:
        ValueError: if :obj:`transforms` is empty, if :obj:`p` has a different
            length than :obj:`transforms`, or if :obj:`p` contains negative values
            or sums to zero.
    """

    def __init__(
        self, transforms: List[Callable], p: Optional[List[float]] = None
    ) -> None:
        if len(transforms) == 0:
            raise ValueError("RandomChoice requires at least one transform.")
        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(
                f"Length of p doesn't match the number of transforms: "
                f"{len(p)} != {len(transforms)}"
            )
        # Validate here so bad weights fail at construction with a clear message,
        # instead of a ZeroDivisionError below or an opaque error at sampling time.
        if any(prob < 0 for prob in p):
            raise ValueError(
                f"Probabilities in p must be non-negative, got {list(p)}."
            )
        total = sum(p)
        if total <= 0:
            raise ValueError(
                f"Probabilities in p must not all be zero, got {list(p)}."
            )
        self.transforms = transforms
        # Normalize so the weights form a distribution accepted by np.random.choice.
        self.p = [prob / total for prob in p]

    def __call__(self, data: temporaldata.Data) -> temporaldata.Data:
        # Sampling uses NumPy's global RNG, so np.random.seed makes it reproducible.
        idx = np.random.choice(len(self.transforms), p=self.p)
        return self.transforms[idx](data)
# Argument order mirrors jax.lax.cond(pred, true_fun, false_fun, ...).
[docs]
class ConditionalChoice:
    r"""Conditionally apply a single transformation based on whether a condition is met.

    Args:
        condition: callable that takes a data object and returns a boolean. Both
            Python :obj:`bool` and NumPy :obj:`numpy.bool_` results are accepted.
        true_transform: transformation to apply if the condition is met
        false_transform: transformation to apply if the condition is not met

    Raises:
        ValueError: when called, if :obj:`condition` returns anything other than
            a Python or NumPy boolean.
    """

    def __init__(
        self, condition: Callable, true_transform: Callable, false_transform: Callable
    ) -> None:
        self.condition = condition
        self.true_transform = true_transform
        self.false_transform = false_transform

    def __call__(self, data: temporaldata.Data) -> temporaldata.Data:
        ret = self.condition(data)
        # Conditions computed with NumPy (e.g. ``arr.sum() > 0``) return np.bool_,
        # which is not a subclass of bool. Accept both, but still reject other
        # truthy values (ints, arrays, ...) so mistakes in `condition` surface early.
        if not isinstance(ret, (bool, np.bool_)):
            raise ValueError(
                f"Condition must return a boolean, got {type(ret)} instead."
            )
        if ret:
            return self.true_transform(data)
        return self.false_transform(data)