Source code for torch_brain.transforms.output_sampler
import torch
[docs]
class RandomOutputSampler:
def __init__(self, num_output_tokens):
self.num_output_tokens = num_output_tokens
def __call__(self, data):
out = data.behavior
timestamps = out.timestamps
if len(timestamps) <= self.num_output_tokens:
return data
# sample from timestamps
mask = torch.zeros(len(timestamps), dtype=bool)
mask[torch.randperm(len(timestamps))[: self.num_output_tokens]] = True
for key, value in out.__dict__.items():
out.__dict__[key] = value[mask].clone()
data.behavior = out
return data