Collate
collate | An extended collate function that handles padding and chaining.
pad | A wrapper to call when padding.
track_mask | A wrapper to call to track the padding mask during padding.
pad8 | A wrapper to call when padding, but the length is rounded up to the nearest multiple of 8.
track_mask8 | A wrapper to call to track the padding mask during padding with pad8.
chain | A wrapper to call when chaining.
track_batch | A wrapper to call to track the batch index during chaining.
- collate(batch)
Extension of PyTorch’s default_collate function that enables more advanced collation of samples of variable lengths. To specify the collation recipe, wrap the objects using pad or chain. If the wrapped object is an Iterable or Mapping, all of its elements inherit the collation recipe. Any object already supported by default_collate can be wrapped.
If an object is not wrapped, the default collation recipe is used, i.e. the outcome is identical to that of default_collate.
Neither pad nor chain tracks padding masks or batch indices, since these are not always needed. Use track_mask or track_batch to track the mask or batch index for a particular array or tensor.
- Parameters:
batch – a single batch to be collated
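The recipe dispatch described above can be illustrated with a minimal pure-Python sketch. It is not this module's implementation: the marker classes Pad and Chain, the function sketch_collate, and the use of plain lists instead of tensors are all hypothetical, and only Mapping nesting is handled (not general Iterables) to keep it short. It shows a wrapped Mapping passing its recipe to every member, and an unwrapped field falling back to default behaviour.

```python
from collections.abc import Mapping


class Pad:
    """Hypothetical marker: pad the wrapped object (or its members)."""

    def __init__(self, obj):
        self.obj = obj


class Chain:
    """Hypothetical marker: chain the wrapped object (or its members)."""

    def __init__(self, obj):
        self.obj = obj


def sketch_collate(samples, recipe=None):
    """Collate a list of samples; recipe is None, "pad" or "chain"."""
    first = samples[0]
    if isinstance(first, (Pad, Chain)):
        # Unwrap the samples and remember the recipe for nested members.
        recipe = "pad" if isinstance(first, Pad) else "chain"
        samples = [s.obj for s in samples]
        first = samples[0]
    if isinstance(first, Mapping):
        # Every member of a Mapping inherits the recipe of its parent.
        return {k: sketch_collate([s[k] for s in samples], recipe) for k in first}
    if recipe == "pad":
        # Pad each sequence with zeros up to the longest one in the batch.
        longest = max(len(s) for s in samples)
        return [list(s) + [0] * (longest - len(s)) for s in samples]
    if recipe == "chain":
        # Concatenate all sequences into one long sequence.
        return [x for s in samples for x in s]
    # Unwrapped: stand-in for default_collate, which would stack the
    # samples into a tensor; here they are simply gathered into a list.
    return list(samples)


padded = sketch_collate([
    {"tokens": Pad([1, 2, 3]), "label": 0},
    {"tokens": Pad([4]), "label": 1},
])
chained = sketch_collate([
    Chain({"nodes": [1, 2], "edges": [7]}),
    Chain({"nodes": [3], "edges": [8, 9]}),
])
print(padded)   # {'tokens': [[1, 2, 3], [4, 0, 0]], 'label': [0, 1]}
print(chained)  # {'nodes': [1, 2, 3], 'edges': [7, 8, 9]}
```

In the second call, wrapping the whole dictionary in Chain is enough for both "nodes" and "edges" to be chained, mirroring the inheritance rule stated above.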
- pad(obj)
Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch. The object can be any of the objects that PyTorch’s default_collate already supports.
- Parameters:
obj – can be tensors, numpy arrays, lists, tuples, or dictionaries.
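To make "it (or any of its members)" concrete, here is a hedged pure-Python sketch for a batch of tuple samples. The helpers pad_to_max and pad_members are hypothetical, 0 is an assumed padding value, and each tuple member is padded independently to its own maximum length across the batch.

```python
def pad_to_max(seqs, value=0):
    """Hypothetical helper: pad 1-D sequences to the longest length in seqs."""
    longest = max(len(s) for s in seqs)
    return [list(s) + [value] * (longest - len(s)) for s in seqs]


def pad_members(samples, value=0):
    """Hypothetical helper: pad each member of tuple samples independently."""
    # zip(*samples) groups the i-th member of every sample together.
    return tuple(pad_to_max(member, value) for member in zip(*samples))


samples = [([1, 2, 3], [9]), ([4], [8, 7])]
tokens, extras = pad_members(samples)
print(tokens)  # [[1, 2, 3], [4, 0, 0]]
print(extras)  # [[9, 0], [8, 7]]
```

Note that the first member is padded to length 3 and the second to length 2: each member's maximum is computed separately.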
- track_mask(input)
Wrap an array or tensor to specify that its padding mask should be tracked.
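A padding mask records which positions hold real data and which hold padding. The sketch below is hypothetical (pad_with_mask is not part of this module) and assumes the convention True = real element, False = padding; the mask layout and polarity actually produced by track_mask are not stated here.

```python
def pad_with_mask(seqs, value=0):
    """Hypothetical helper: pad 1-D sequences and return a matching mask."""
    longest = max(len(s) for s in seqs)
    padded, mask = [], []
    for s in seqs:
        n_pad = longest - len(s)
        padded.append(list(s) + [value] * n_pad)
        # Assumed convention: True marks real elements, False marks padding.
        mask.append([True] * len(s) + [False] * n_pad)
    return padded, mask


padded, mask = pad_with_mask([[5, 6, 7], [8]])
print(padded)  # [[5, 6, 7], [8, 0, 0]]
print(mask)    # [[True, True, True], [True, False, False]]
```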
- pad8(obj)
Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch. This function is similar to pad, except that the padded length is rounded up to the nearest multiple of 8.
- Parameters:
obj – can be tensors, numpy arrays, lists, tuples, or dictionaries.
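Rounding up to the nearest multiple of 8 means the padded length is ceil(L / 8) * 8, where L is the longest length in the batch. A hypothetical pure-Python sketch using integer arithmetic follows; round_up_to_multiple and pad8_to_max are illustrative names, and 0 is an assumed padding value.

```python
def round_up_to_multiple(n, multiple=8):
    """Smallest multiple of `multiple` that is >= n (integer ceiling)."""
    return (n + multiple - 1) // multiple * multiple


def pad8_to_max(seqs, value=0):
    """Hypothetical helper: pad to the longest length, rounded up to 8."""
    target = round_up_to_multiple(max(len(s) for s in seqs))
    return [list(s) + [value] * (target - len(s)) for s in seqs]


print([round_up_to_multiple(n) for n in (1, 8, 9, 16, 17)])  # [8, 8, 16, 16, 24]
print([len(row) for row in pad8_to_max([[1, 2, 3], [4] * 10])])  # [16, 16]
```

A length that is already a multiple of 8 is left unchanged, so the extra padding is at most 7 positions per sequence beyond what pad would add.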
- track_mask8(input)
Wrap an array or tensor to specify that its padding mask should be tracked. This is used in conjunction with pad8.
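When padding is rounded up to a multiple of 8, the mask has to cover the rounded length, so even the longest sequence in the batch can end in padded positions. A hypothetical sketch with the same assumed convention (True = real element, False = padding); pad8_with_mask is an illustrative name, not this module's API.

```python
def pad8_with_mask(seqs, value=0, multiple=8):
    """Hypothetical helper: pad to a multiple of 8 and build the mask."""
    longest = max(len(s) for s in seqs)
    # Integer ceiling: round the longest length up to the next multiple.
    target = (longest + multiple - 1) // multiple * multiple
    padded, mask = [], []
    for s in seqs:
        n_pad = target - len(s)
        padded.append(list(s) + [value] * n_pad)
        # Assumed convention: True marks real elements, False marks padding.
        mask.append([True] * len(s) + [False] * n_pad)
    return padded, mask


padded, mask = pad8_with_mask([[1, 2, 3], [4, 5, 6, 7, 8]])
print([len(row) for row in padded])  # [8, 8]
print([sum(row) for row in mask])    # [3, 5]
```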
- chain(obj, allow_missing_keys=False)
Wrap an object to specify that it (or any of its members) should be concatenated along the first dimension when batching. This approach is similar to PyTorch Geometric’s collate approach for graphs: all the sequences in the batch are chained into one large sequence. To track which sample in the batch each element of the sequence came from, use track_batch().
- Parameters: