chain

torch_brain.data.chain(obj, allow_missing_keys=False)

Wrap an object to specify that it (or any of its members) should be concatenated along the first dimension when batching. This mirrors how PyTorch Geometric collates graphs: all the sequences in the batch are chained into one long sequence. To track which sample in the batch each element of the chained sequence came from, use track_batch().
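The snippet below is not torch_brain's implementation. It is a minimal pure-Python sketch of what chaining means, using a hypothetical helper `chain_sequences`: per-sample sequences of different lengths are concatenated end to end, and a parallel list records which sample each element came from (the information the text above attributes to track_batch()).

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def chain_sequences(samples: Sequence[Sequence[T]]) -> Tuple[List[T], List[int]]:
    """Hypothetical helper: concatenate per-sample sequences along the first dimension.

    Returns the chained sequence and, for every element, the index of the
    sample in the batch that it came from.
    """
    chained: List[T] = []
    batch_index: List[int] = []
    for sample_idx, seq in enumerate(samples):
        chained.extend(seq)                          # append this sample's elements
        batch_index.extend([sample_idx] * len(seq))  # remember where they came from
    return chained, batch_index


# Three samples of different lengths are chained into one sequence of length 6.
values, batch = chain_sequences([[0.1, 0.4], [0.2], [0.5, 0.7, 0.9]])
print(values)  # [0.1, 0.4, 0.2, 0.5, 0.7, 0.9]
print(batch)   # [0, 0, 1, 2, 2, 2]
```

With PyTorch tensors, the same idea maps onto `torch.cat` for the values and `torch.repeat_interleave` over per-sample lengths for the batch index.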

Parameters:
  • obj – The object to wrap. Can be a tensor, NumPy array, list, tuple, or dictionary.

  • allow_missing_keys (bool) – If True, the wrapped dictionary may have different keys in different samples of the batch. When a sample is missing a key, that sample is skipped when collating the values for that key. Defaults to False.
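To make the allow_missing_keys behavior concrete, here is a hypothetical pure-Python sketch (`chain_dicts` is not part of torch_brain) that chains a batch of dictionaries key by key. It assumes that strict mode rejects mismatched keys by raising an error; the exception type used here, ValueError, is an illustrative choice and not taken from the library.

```python
from typing import Any, Dict, List, Sequence


def chain_dicts(
    samples: Sequence[Dict[str, List[Any]]],
    allow_missing_keys: bool = False,
) -> Dict[str, List[Any]]:
    """Hypothetical sketch: chain a batch of dictionaries key by key."""
    all_keys = set()
    for sample in samples:
        all_keys.update(sample)

    if not allow_missing_keys:
        # Assumed strict mode: every sample must have exactly the same keys.
        for sample in samples:
            if set(sample) != all_keys:
                raise ValueError(
                    f"inconsistent keys: {sorted(sample)} vs {sorted(all_keys)}"
                )

    out: Dict[str, List[Any]] = {}
    for key in sorted(all_keys):
        chained: List[Any] = []
        for sample in samples:
            if key in sample:  # samples missing this key are skipped
                chained.extend(sample[key])
        out[key] = chained
    return out


batch = [{"spikes": [1, 2], "labels": [0]}, {"spikes": [3]}]
print(chain_dicts(batch, allow_missing_keys=True))
# {'labels': [0], 'spikes': [1, 2, 3]}
```

Calling `chain_dicts(batch)` with the default `allow_missing_keys=False` raises, because the second sample has no "labels" entry.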