Collate

collate

An extended collate function that handles padding and chaining.

pad

A wrapper to call when padding.

track_mask

A wrapper to call to track the padding mask during padding.

pad8

A wrapper to call when padding, with the padded length rounded up to the nearest multiple of 8.

track_mask8

A wrapper to call to track the padding mask during padding with pad8.

chain

A wrapper to call when chaining.

track_batch

A wrapper to call to track the batch index during chaining.

collate(batch)

Extension of PyTorch’s default_collate function to enable more advanced collation of samples of variable lengths.

To specify the collation recipe, wrap the objects using pad or chain. If the wrapped object is an Iterable or Mapping, all of its elements inherit the collation recipe. Any object that default_collate already supports can be wrapped.

If an object is not wrapped, the default collation recipe is used, i.e. the outcome is identical to that of default_collate.

Neither pad nor chain tracks padding masks or batch indices, since these are not always needed. Use track_mask or track_batch to track the mask or the batch index for a particular array or tensor.

Parameters:

batch – a single batch to be collated
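The wrap-then-dispatch idea described above can be illustrated with a small pure-Python sketch. It is not this library's implementation: the names Wrapped, mark_pad, mark_chain and toy_collate are hypothetical, plain lists stand in for tensors, and the fallback branch only gathers values into a list where the real function would defer to default_collate.

```python
class Wrapped:
    """Hypothetical marker that tags a value with a collation recipe."""

    def __init__(self, value, recipe):
        self.value = value
        self.recipe = recipe  # either "pad" or "chain"


def mark_pad(value):
    # Hypothetical stand-in for wrapping a value with pad.
    return Wrapped(value, "pad")


def mark_chain(value):
    # Hypothetical stand-in for wrapping a value with chain.
    return Wrapped(value, "chain")


def toy_collate(batch):
    """Collate a list of samples, dispatching on the wrapper of each field."""
    first = batch[0]
    if isinstance(first, dict):
        # Recurse into each key, as default_collate does for mappings.
        return {key: toy_collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, Wrapped):
        values = [item.value for item in batch]
        if first.recipe == "pad":
            max_len = max(len(v) for v in values)
            return [list(v) + [0] * (max_len - len(v)) for v in values]
        # "chain": concatenate all sequences into one long sequence.
        return [x for v in values for x in v]
    # Unwrapped: assumed stand-in for the default recipe.
    return list(batch)


samples = [
    {"tokens": mark_pad([1, 2, 3]), "edges": mark_chain([10, 11]), "label": 0},
    {"tokens": mark_pad([4]), "edges": mark_chain([12]), "label": 1},
]
collated = toy_collate(samples)
```

The sketch omits the case where a wrapped Iterable or Mapping passes its recipe on to its members.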

pad(obj)

Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch. The object can be any of the objects that PyTorch’s default_collate already supports.

Parameters:

obj – Can be tensors, numpy arrays, lists, tuples, or dictionaries.
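As a rough illustration of what padding to the maximum length in the batch means, the sketch below right-pads plain Python lists. The helper pad_to_max and the pad value of 0 are assumptions made for this example; the library itself operates on tensors and arrays, and its padding value and padding side are not stated here.

```python
def pad_to_max(sequences, pad_value=0):
    """Right-pad every sequence to the longest length found in the batch."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]


batch = [[1, 2, 3], [4], [5, 6]]
padded = pad_to_max(batch)
# Every row now has length 3, so the batch could be stacked into one array.
```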

track_mask(input)

Wrap an array or tensor to specify that its padding mask should be tracked.

Parameters:

input (Union[Tensor, ndarray]) – An array or tensor.
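A padding mask records which positions of a padded row hold real data. The sketch below is a minimal pure-Python illustration, assuming True marks a real element and False marks padding; the helper pad_with_mask is hypothetical, and the convention the library actually uses may be the opposite.

```python
def pad_with_mask(sequences, pad_value=0):
    """Right-pad sequences and return a boolean mask of the same shape."""
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded.append(list(seq) + [pad_value] * n_pad)
        # Assumed convention: True for real elements, False for padding.
        mask.append([True] * len(seq) + [False] * n_pad)
    return padded, mask


padded, mask = pad_with_mask([[7, 8, 9], [1]])
```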

pad8(obj)

Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch, rounded up to the nearest multiple of 8. Apart from this rounding, the function behaves like pad.

Parameters:

obj – Can be tensors, numpy arrays, lists, tuples, or dictionaries.
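The rounding step can be sketched as follows. The helpers round_up_to_multiple and pad_to_multiple are hypothetical names used only for this illustration, and plain lists again stand in for tensors.

```python
def round_up_to_multiple(length, multiple=8):
    """Return the smallest multiple of `multiple` that is >= length."""
    return ((length + multiple - 1) // multiple) * multiple


def pad_to_multiple(sequences, multiple=8, pad_value=0):
    """Right-pad sequences to the batch maximum, rounded up to a multiple."""
    target = round_up_to_multiple(max(len(seq) for seq in sequences), multiple)
    return [list(seq) + [pad_value] * (target - len(seq)) for seq in sequences]


# The longest sequence has 11 elements, so every row is padded to 16.
padded8 = pad_to_multiple([[1] * 5, [2] * 11])
```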

track_mask8(input)

Wrap an array or tensor to specify that its padding mask should be tracked. This is used in conjunction with pad8.

Parameters:

input (Union[Tensor, ndarray]) – An array or tensor.

chain(obj, allow_missing_keys=False)

Wrap an object to specify that it (or any of its members) should be concatenated along the first dimension when batching, so that all the sequences in the batch are chained into one large sequence. This approach is similar to PyTorch Geometric’s collate approach for graphs. To track which sample of the batch each element of the chained sequence came from, use track_batch().

Parameters:
  • obj – Can be tensors, numpy arrays, lists, tuples, or dictionaries.

  • allow_missing_keys (bool) – If set to True, this allows the wrapped dictionary to have inconsistent keys across the batch. If an instance is missing keys, the corresponding values will be skipped when collating. Defaults to False.
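The sketch below illustrates chaining and the effect described for allow_missing_keys using plain Python lists and dictionaries. chain_sequences and chain_dicts are hypothetical helpers; in particular, raising KeyError when keys are inconsistent is an assumption about the strict mode, not documented behaviour.

```python
def chain_sequences(sequences):
    """Concatenate all sequences of the batch into one long sequence."""
    chained = []
    for seq in sequences:
        chained.extend(seq)
    return chained


def chain_dicts(samples, allow_missing_keys=False):
    """Chain the values of each key across a batch of dictionaries."""
    keys = []
    for sample in samples:
        for key in sample:
            if key not in keys:
                keys.append(key)
    out = {}
    for key in keys:
        present = [sample[key] for sample in samples if key in sample]
        if len(present) != len(samples) and not allow_missing_keys:
            # Assumed strict behaviour: inconsistent keys are an error.
            raise KeyError(f"key {key!r} is missing from some samples")
        # Samples lacking the key are skipped for this key.
        out[key] = chain_sequences(present)
    return out


samples = [{"x": [1, 2], "y": [9]}, {"x": [3]}]
chained = chain_dicts(samples, allow_missing_keys=True)
```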

track_batch(input)

Wrap an array or tensor to specify that its batch index should be tracked during chaining.

Parameters:

input (Union[Tensor, ndarray]) – An array or tensor.
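A batch index records, for each element of the chained sequence, which sample of the batch it came from. This is a minimal pure-Python sketch of that idea; chain_with_batch_index is a hypothetical helper, and the form in which the library returns the index is not specified here.

```python
def chain_with_batch_index(sequences):
    """Chain sequences and record the source sample of every element."""
    chained, batch_index = [], []
    for sample_id, seq in enumerate(sequences):
        chained.extend(seq)
        batch_index.extend([sample_id] * len(seq))
    return chained, batch_index


chained, batch_index = chain_with_batch_index([["a", "b"], ["c"], ["d", "e", "f"]])
```

With such an index, per-sample quantities can be recovered from the chained sequence by grouping elements that share the same index value.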