Samplers¶
- A sequential sampler that samples a fixed-length window from data.
- A random sampler that samples a fixed-length window from data.
- A sampler that randomly samples a single trial interval from given intervals.
- A wrapper sampler for distributed training that assigns samples to processes.
- A distributed sampler for evaluation that enables sliding window inference with prediction stitching.
- class SequentialFixedWindowSampler(*, sampling_intervals, window_length, step=None, drop_short=False)[source]¶
Bases: Sampler

Samples fixed-length windows sequentially, always in the same order. The sampling intervals are defined in the sampling_intervals parameter. sampling_intervals is a dictionary where the keys are the session ids and the values are lists of tuples representing the start and end of the intervals from which to sample.

If the length of a sequence is not evenly divisible by the step, the last window is added with an overlap with the previous window. This ensures that the entire sequence is covered.
- Parameters:
sampling_intervals (Dict[str, List[Tuple[int, int]]]) – Sampling intervals for each session in the dataset.
window_length (float) – Length of the window to sample.
step (Optional[float], optional) – Step size between windows. If None, it defaults to window_length. Defaults to None.
drop_short (bool, optional) – Whether to drop short intervals. Defaults to False.
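The tail-coverage rule can be sketched in plain Python. This is a minimal illustration of the documented behavior, not the library's implementation; the function name and the single-interval framing are assumptions.

```python
def sequential_windows(start, end, window_length, step=None):
    """Yield (start, end) windows covering [start, end] sequentially.

    If the interval length is not evenly divisible by the step, a final
    window ending exactly at `end` is added, overlapping the previous one,
    so the whole interval is covered.
    """
    step = window_length if step is None else step
    windows = []
    t = start
    while t + window_length <= end:
        windows.append((t, t + window_length))
        t += step
    # Cover the tail with an overlapping last window, if needed.
    if not windows or windows[-1][1] < end:
        windows.append((end - window_length, end))
    return windows

print(sequential_windows(0.0, 10.0, window_length=4.0))
# [(0.0, 4.0), (4.0, 8.0), (6.0, 10.0)] — last window overlaps the previous one
```

Note that the last window (6.0, 10.0) overlaps (4.0, 8.0): the interval length 10 is not divisible by the step 4, so an overlapping window is appended to reach the end.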
- class RandomFixedWindowSampler(*, sampling_intervals, window_length, generator=None, drop_short=True)[source]¶
Bases:
Sampler
Samples fixed-length windows randomly, given intervals defined in the sampling_intervals parameter. sampling_intervals is a dictionary where the keys are the session ids and the values are lists of tuples representing the start and end of the intervals from which to sample. The samples are shuffled, and random temporal jitter is applied.

In one epoch, the number of samples generated from a given sampling interval is:

\[N = \left\lfloor\frac{\text{interval_length}}{\text{window_length}}\right\rfloor\]

- Parameters:
sampling_intervals (Dict[str, List[Tuple[int, int]]]) – Sampling intervals for each session in the dataset.
window_length (float) – Length of the window to sample.
generator (Optional[torch.Generator], optional) – Generator for shuffling. Defaults to None.
drop_short (bool, optional) – Whether to drop short intervals. Defaults to True.
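The per-epoch sample count implied by the formula above can be checked with a short snippet. The interval values here are made up for illustration.

```python
import math

# Hypothetical sampling intervals for one session.
sampling_intervals = {
    "session_1": [(0.0, 10.0), (12.0, 15.0)],
}
window_length = 2.0

# N = floor(interval_length / window_length) for each interval
counts = [
    math.floor((end - start) / window_length)
    for intervals in sampling_intervals.values()
    for start, end in intervals
]
print(counts, sum(counts))  # [5, 1] 6 -> 6 windows per epoch
```

With drop_short=True (the default here), intervals shorter than window_length contribute zero windows and are dropped.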
- class TrialSampler(*, sampling_intervals, generator=None, shuffle=False)[source]¶
Bases: Sampler

Randomly samples a single trial interval from the given intervals.
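A minimal sketch of this sampling behavior, assuming (based on the summary above) that one interval is drawn uniformly at random across all sessions; the function name and tuple layout are illustrative, not the library's API.

```python
import random

def sample_trial(sampling_intervals, rng=random):
    """Pick one trial interval uniformly at random from all sessions."""
    flat = [
        (session_id, start, end)
        for session_id, intervals in sampling_intervals.items()
        for start, end in intervals
    ]
    return rng.choice(flat)

intervals = {"session_1": [(0.0, 1.0), (2.0, 3.0)]}
print(sample_trial(intervals))  # one of the two trials, chosen at random
```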
- class DistributedStitchingFixedWindowSampler(*, sampling_intervals, window_length, step=None, batch_size, num_replicas=None, rank=None)[source]¶
Bases: DistributedSampler
A sampler designed specifically for evaluation that enables sliding window inference with prediction stitching across distributed processes.
This sampler divides sequences into overlapping windows and distributes them across processes for parallel inference. Windows that need to be stitched together are kept on the same rank, so stitching can happen on that rank without inter-process communication.

Additionally, it tracks which windows belong to the same contiguous sequence, so stitching can begin as soon as all windows from that sequence are available. This information can be passed to the stitcher, which can then stitch and compute a metric for a sequence as soon as its windows are complete, freeing up memory quickly.
- Parameters:
sampling_intervals (Dict[str, List[Tuple[int, int]]]) – Sampling intervals for each session in the dataset. Each interval is defined by a start and end time.
window_length (float) – Length of the sliding window.
step (Optional[float], optional) – Step size between windows. If None, defaults to window_length. Smaller steps create more overlap between windows.
batch_size (int) – Number of windows to process in each batch.
num_replicas (Optional[int], optional) – Number of processes participating in distributed inference. If None, will be set using torch.distributed.
rank (Optional[int], optional) – Rank of the current process. If None, will be set using torch.distributed.
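The rank-assignment idea can be illustrated with a self-contained sketch: whole sequences, rather than individual windows, are assigned to ranks, so every window of a sequence lands on the same process and can be stitched there locally. This is an assumption-laden illustration of the strategy described above, not the library's code.

```python
def assign_sequences_to_ranks(sequence_window_counts, num_replicas):
    """Greedily assign whole sequences to ranks, balancing window counts.

    Keeping every window of a sequence on one rank lets that rank stitch
    predictions locally, without cross-process communication.
    """
    loads = [0] * num_replicas
    assignment = {}
    # Place the largest sequences first for a better load balance.
    for seq_id, n_windows in sorted(
        sequence_window_counts.items(), key=lambda kv: -kv[1]
    ):
        rank = loads.index(min(loads))  # least-loaded rank so far
        assignment[seq_id] = rank
        loads[rank] += n_windows

    return assignment

# Hypothetical window counts per contiguous sequence.
counts = {"seq_a": 10, "seq_b": 7, "seq_c": 3, "seq_d": 2}
print(assign_sequences_to_ranks(counts, num_replicas=2))
# {'seq_a': 0, 'seq_b': 1, 'seq_c': 1, 'seq_d': 0}
```

Each rank then runs sliding-window inference over its own sequences and can stitch and score a sequence as soon as its last window has been processed.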