collate

torch_brain.data.collate(batch)

Extension of PyTorch’s default_collate function to enable more advanced collation of samples of variable lengths.

To specify a collation recipe, wrap an object in pad or chain. If the wrapped object is an Iterable or Mapping, all of its elements inherit that recipe. Any object that default_collate already supports can be wrapped.

If an object is not wrapped, the default collation recipe will be used, i.e., the outcome will be identical to default_collate.

Neither pad nor chain tracks padding masks or batch indices, since these are not always needed. Use track_mask or track_batch to track the mask or batch index for a particular array or tensor.
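
To make the difference between the two recipes concrete, here is a minimal conceptual sketch of what padding and chaining do to variable-length samples, and what a mask or batch index would record in each case. It uses plain Python lists in place of tensors. The helper names pad_sketch and chain_sketch are hypothetical and are not part of torch_brain, and the mask convention (True for real entries, False for padding) is an assumption made for illustration.

```python
# Conceptual sketch only: plain Python lists stand in for tensors, and the
# helper names below are hypothetical, not torch_brain functions.

def pad_sketch(seqs, pad_value=0):
    """Pad variable-length sequences to a common length (one row per sample)."""
    max_len = max(len(s) for s in seqs)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    # Assumed mask convention: True marks real entries, False marks padding.
    mask = [[True] * len(s) + [False] * (max_len - len(s)) for s in seqs]
    return padded, mask


def chain_sketch(seqs):
    """Concatenate variable-length sequences end to end into one flat list."""
    chained = [x for s in seqs for x in s]
    # The batch index records which sample each chained element came from.
    batch_index = [i for i, s in enumerate(seqs) for _ in s]
    return chained, batch_index


samples = [[1, 2, 3], [4, 5]]
padded, mask = pad_sketch(samples)
chained, batch_index = chain_sketch(samples)

print(padded)       # [[1, 2, 3], [4, 5, 0]]
print(mask)         # [[True, True, True], [True, True, False]]
print(chained)      # [1, 2, 3, 4, 5]
print(batch_index)  # [0, 0, 0, 1, 1]
```

Padding keeps one row per sample at the cost of filler values, so a mask is needed to tell real entries from padding. Chaining avoids filler but loses the sample boundaries, so a batch index is needed to recover them.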

Parameters:

batch – a single batch to be collated