# Source code for torch_brain.transforms.random_time_scaling
import copy
import torch
from temporaldata import IrregularTimeSeries, RegularTimeSeries, Interval, Data
def _rescale_interval(interval: Interval, scale: float, offset: float) -> Interval:
    """Return a shallow copy of ``interval`` with ``start`` and ``end`` mapped
    through ``t * scale + offset``. The input interval is not modified."""
    scaled = copy.copy(interval)
    scaled.start = scaled.start * scale + offset
    scaled.end = scaled.end * scale + offset
    return scaled


def rescale(data: Data, scale: float, offset: float) -> Data:
    r"""Rescale the time axis of the data by a factor and offset.

    Every time value ``t`` handled here is mapped to ``t * scale + offset``.
    A new object of the same class as ``data`` is returned; ``data`` itself is
    left untouched and each attribute is shallow-copied into the result.

    Per attribute type:

    * ``IrregularTimeSeries``: ``timestamps`` and the series domain are mapped.
    * ``RegularTimeSeries``: the sampling rate is divided by ``scale`` and the
      series domain is mapped.
    * ``Interval``: ``start`` and ``end`` are mapped.
    * nested ``Data`` with a domain: rescaled recursively, so that its
      contents stay on the same time axis as the parent.
    * anything else: shallow-copied unchanged.

    Args:
        data (Data): The data to rescale.
        scale (float): The scaling factor. Must be non-zero, since the
            sampling rate of regular time series is divided by it.
        offset (float): The offset.

    Returns:
        Data: A rescaled copy of ``data`` with the same ``_absolute_start``.
    """
    out = data.__class__.__new__(data.__class__)

    for key, value in data.__dict__.items():
        if key == "_domain":
            # Placeholder only (keeps attribute order); the rescaled domain is
            # assigned once the loop is done.
            out.__dict__[key] = copy.copy(value)
        elif isinstance(value, IrregularTimeSeries):
            series = copy.copy(value)
            series.timestamps = series.timestamps * scale + offset
            series._domain = _rescale_interval(value._domain, scale, offset)
            out.__dict__[key] = series
        elif isinstance(value, RegularTimeSeries):
            series = copy.copy(value)
            # Stretching time by `scale` spaces samples `scale` times further
            # apart, hence the division.
            series._sampling_rate = series._sampling_rate / scale
            series._domain = _rescale_interval(value._domain, scale, offset)
            out.__dict__[key] = series
        elif isinstance(value, Interval):
            out.__dict__[key] = _rescale_interval(value, scale, offset)
        elif isinstance(value, Data) and getattr(value, "_domain", None) is not None:
            # Nested containers that live on the time axis must be rescaled
            # too, otherwise parent and child end up on different time axes.
            out.__dict__[key] = rescale(value, scale, offset)
        else:
            out.__dict__[key] = copy.copy(value)

    # update domain
    out._domain = _rescale_interval(data._domain, scale, offset)

    # the slice start time is carried over unchanged
    out._absolute_start = data._absolute_start

    return out
# [docs]
class RandomTimeScaling:
    r"""Randomly rescale, and optionally shift, the time axis of a data object.

    On every call a scale is drawn uniformly between ``min_scale`` and
    ``max_scale`` and an offset uniformly between ``min_offset`` and
    ``max_offset``, using torch's global random number generator. Both values
    are then passed to :func:`rescale`.

    Args:
        min_scale (float): Lower bound of the scaling factor. Must be > 0.
        max_scale (float): Upper bound of the scaling factor. Must be > 0.
        min_offset (float, optional): Lower bound of the offset. Defaults to 0.
        max_offset (float, optional): Upper bound of the offset. Defaults to 0.

    Raises:
        ValueError: If ``min_scale`` or ``max_scale`` is not strictly positive.
    """

    def __init__(self, min_scale, max_scale, min_offset=0, max_offset=0):
        # `rescale` divides sampling rates by the scale and maps domain bounds
        # through it, so a zero or negative (or NaN) scale cannot produce a
        # usable time axis. Fail early instead of at transform time.
        if not (min_scale > 0 and max_scale > 0):
            raise ValueError(
                "min_scale and max_scale must be strictly positive, got "
                f"min_scale={min_scale}, max_scale={max_scale}"
            )
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_offset = min_offset
        self.max_offset = max_offset

    @staticmethod
    def _uniform(low, high):
        """Draw one Python float uniformly between ``low`` and ``high``."""
        return torch.rand(1).item() * (high - low) + low

    def __call__(self, data):
        """Return a copy of ``data`` rescaled with a freshly drawn scale and offset."""
        # The scale is drawn before the offset.
        scale = self._uniform(self.min_scale, self.max_scale)
        offset = self._uniform(self.min_offset, self.max_offset)
        return rescale(data, scale, offset)