Source code for torch_brain.transforms.output_sampler
import torch
class RandomOutputSampler:
    """Transform that randomly keeps a fixed number of behavior samples.

    When ``data.behavior`` holds more timestamps than ``num_output_tokens``,
    a random subset of exactly ``num_output_tokens`` positions is drawn
    without replacement and every attribute stored on ``data.behavior`` is
    reduced to those positions. Selection is applied through a boolean mask,
    so the surviving samples keep their original relative order.

    Args:
        num_output_tokens: Maximum number of samples to keep.
    """

    def __init__(self, num_output_tokens):
        self.num_output_tokens = num_output_tokens

    def __call__(self, data):
        """Subsample ``data.behavior`` in place and return ``data``.

        Args:
            data: Object exposing a ``behavior`` attribute, which in turn
                exposes ``timestamps`` and keeps its fields in ``__dict__``.

        Returns:
            The same ``data`` object. It is returned untouched when it has
            ``num_output_tokens`` timestamps or fewer.
        """
        behavior = data.behavior
        total = len(behavior.timestamps)

        # Nothing to drop: already within the token budget.
        if total <= self.num_output_tokens:
            return data

        # Draw distinct positions, then mark them in a boolean mask so that
        # the same selection is applied to every field.
        chosen = torch.randperm(total)[: self.num_output_tokens]
        keep = torch.zeros(total, dtype=bool)
        keep[chosen] = True

        # Every entry of ``__dict__`` is masked and cloned, so each one must
        # support boolean-mask indexing and ``.clone()``.
        # NOTE(review): presumably all fields are torch tensors with one row
        # per timestamp -- confirm against the behavior container type.
        fields = behavior.__dict__
        for name, column in fields.items():
            fields[name] = column[keep].clone()

        data.behavior = behavior
        return data