DistributedStitchingFixedWindowSampler
- class torch_brain.data.sampler.DistributedStitchingFixedWindowSampler(*, sampling_intervals, window_length, step=None, batch_size, num_replicas=None, rank=None)
Bases: torch.utils.data.distributed.DistributedSampler
A sampler designed specifically for evaluation that enables sliding-window inference with prediction stitching across distributed processes.
This sampler divides sequences into overlapping windows and distributes them across processes for parallel inference. It keeps all windows that need to be stitched together on the same rank, so that stitching can be done on that rank without inter-process communication.
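The partitioning idea described above can be illustrated with a minimal, library-independent sketch. This is not torch_brain's implementation: the helper names `sliding_windows` and `assign_to_ranks`, the extra end-aligned window that covers the tail of an interval, and the greedy largest-first load balancing are all assumptions made for illustration. What it does show is the key property: every window cut from one interval is assigned to a single rank.

```python
from typing import Dict, List, Optional, Tuple

Window = Tuple[str, float, float]  # (session_id, window_start, window_end)


def sliding_windows(session: str, start: float, end: float,
                    window_length: float, step: Optional[float] = None) -> List[Window]:
    """Cut [start, end] into fixed-length windows spaced `step` apart."""
    step = window_length if step is None else step  # default: no overlap
    if end - start < window_length:
        return []  # assumption: intervals shorter than one window are skipped
    n = int((end - start - window_length) // step) + 1
    windows = [(session, start + i * step, start + i * step + window_length)
               for i in range(n)]
    # Assumption: add one window aligned to `end` so the tail is covered.
    if windows[-1][2] < end:
        windows.append((session, end - window_length, end))
    return windows


def assign_to_ranks(sampling_intervals: Dict[str, List[Tuple[float, float]]],
                    window_length: float, step: Optional[float],
                    num_replicas: int) -> List[List[List[Window]]]:
    """Hypothetical greedy balancing: each interval's windows stay on one rank."""
    groups = [sliding_windows(session, s, e, window_length, step)
              for session, intervals in sorted(sampling_intervals.items())
              for s, e in intervals]
    groups = [g for g in groups if g]
    groups.sort(key=len, reverse=True)  # place the largest groups first
    ranks: List[List[List[Window]]] = [[] for _ in range(num_replicas)]
    loads = [0] * num_replicas
    for group in groups:
        r = loads.index(min(loads))  # currently least-loaded rank
        ranks[r].append(group)
        loads[r] += len(group)
    return ranks


intervals = {"session_a": [(0.0, 10.0)], "session_b": [(0.0, 4.0), (5.0, 7.0)]}
per_rank = assign_to_ranks(intervals, window_length=2.0, step=1.0, num_replicas=2)
print([[len(g) for g in r] for r in per_rank])  # [[9], [3, 1]]
```

Here the single `session_a` interval yields 9 windows and lands on rank 0, while both `session_b` intervals (3 and 1 windows) go to rank 1; no interval is ever split across ranks, which is what makes communication-free stitching possible.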
Additionally, the sampler keeps track of which windows belong to the same contiguous sequence, so that stitching can begin as soon as all windows from that sequence are available. This information can be passed to the stitcher, which can then stitch the sequence and compute a metric for it immediately, freeing up memory early.
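The bookkeeping described above can be sketched as a small buffer that releases a sequence once its expected number of windows has arrived. This is a hypothetical helper, not torch_brain's stitcher API: the class name `StitchTracker`, its `add` method, and the choice to stitch by averaging predictions at timestamps covered by several overlapping windows are all assumptions for illustration.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Optional


class StitchTracker:
    """Hypothetical helper: buffer per-window predictions, stitch when complete."""

    def __init__(self, expected: Dict[Hashable, int]):
        # expected[seq_id] = number of windows that make up that sequence
        self.expected = dict(expected)
        self.buffer: Dict[Hashable, list] = defaultdict(list)

    def add(self, seq_id: Hashable, timestamps: List[int],
            values: List[float]) -> Optional[Dict[int, float]]:
        """Store one window's output; return the stitched sequence once complete."""
        self.buffer[seq_id].append((timestamps, values))
        if len(self.buffer[seq_id]) < self.expected[seq_id]:
            return None  # still waiting for other windows of this sequence
        sums: Dict[int, float] = defaultdict(float)
        counts: Dict[int, int] = defaultdict(int)
        # pop() frees this sequence's buffered windows right away
        for ts, vs in self.buffer.pop(seq_id):
            for t, v in zip(ts, vs):
                sums[t] += v
                counts[t] += 1
        # Assumed stitching rule: mean over all windows covering a timestamp
        return {t: sums[t] / counts[t] for t in sorted(sums)}


tracker = StitchTracker({"seq0": 2})
tracker.add("seq0", [0, 1, 2], [1.0, 2.0, 3.0])  # returns None: incomplete
print(tracker.add("seq0", [1, 2, 3], [4.0, 4.0, 5.0]))
# {0: 1.0, 1: 3.0, 2: 3.5, 3: 5.0}
```

A caller could compute its evaluation metric on the returned dictionary as soon as it is not `None`, rather than holding every prediction until the end of the epoch.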
- Parameters:
sampling_intervals (Dict[str, List[Tuple[int, int]]]) – Sampling intervals for each session in the dataset. Each interval is defined by a start and end time.
window_length (float) – Length of the sliding window.
step (Optional[float], optional) – Step size between the starts of consecutive windows. If None, defaults to window_length, so windows do not overlap. A smaller step creates more overlap between windows.
batch_size (int) – Number of windows to process in each batch.
num_replicas (Optional[int], optional) – Number of processes participating in distributed inference. If None, it is obtained from torch.distributed.
rank (Optional[int], optional) – Rank of the current process. If None, it is obtained from torch.distributed.