torch_brain.data

torch_brain.data.collate

chain

Wrap an object to specify that it (or any of its members) should be stacked along the first dimension when batching.
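As a rough illustration of the resulting batch layout, the sketch below combines two variable-length samples along the first dimension in plain PyTorch. It assumes "stacked" here means concatenated (`torch.cat`), since samples of different lengths cannot go through `torch.stack`. It does not call `chain` itself, and the sample tensors are made up.

```python
import torch

# Two hypothetical samples with different numbers of rows (3 and 5)
# but the same feature size (2).
sample_a = torch.zeros(3, 2)
sample_b = torch.ones(5, 2)

# Concatenating along dim 0 gives one (3 + 5, 2) tensor with no padding.
chained = torch.cat([sample_a, sample_b], dim=0)
print(chained.shape)  # torch.Size([8, 2])
```

Because no padding is introduced, a layout like this usually needs a companion index that records which sample each row came from.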

collate

Extension of PyTorch's default_collate function to enable more advanced collation of samples of variable lengths.
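The limitation this extension addresses can be seen with PyTorch alone. The sketch below uses only `torch.utils.data.default_collate`, with made-up sample dictionaries, to show that equal-length samples are stacked while variable-length samples raise an error. It does not show the behaviour of `collate` itself.

```python
import torch
from torch.utils.data import default_collate

# Equal-length samples: default_collate stacks them into a (batch, length) tensor.
equal = [{"x": torch.zeros(4)}, {"x": torch.ones(4)}]
batch = default_collate(equal)
print(batch["x"].shape)  # torch.Size([2, 4])

# Variable-length samples: stacking fails, which is the gap an extended
# collate function has to fill (for example by padding or concatenating).
ragged = [{"x": torch.zeros(3)}, {"x": torch.ones(5)}]
try:
    default_collate(ragged)
    stacked_ok = True
except RuntimeError:
    stacked_ok = False
print(stacked_ok)  # False
```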

pad

Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch.
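For readers unfamiliar with this kind of padding, here is a minimal plain-PyTorch sketch of padding a batch to its longest sequence. It uses `torch.nn.utils.rnn.pad_sequence` as a stand-in and assumes a padding value of 0 and right-side padding. The value and side used by `pad` are not stated here and may differ.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two hypothetical sequences of lengths 3 and 2.
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# Pad every sequence on the right up to the longest length in the batch (3).
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
print(padded)
# tensor([[1, 2, 3],
#         [4, 5, 0]])
```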

pad8

Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch.
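The summary above does not say how `pad8` differs from `pad`. One plausible reading, treated here purely as an assumption based on the name, is that the padded length is additionally rounded up to a multiple of 8. The sketch below shows that idea in plain PyTorch. `round_up` is a hypothetical helper and none of this is taken from the library's implementation.

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence


def round_up(n: int, multiple: int = 8) -> int:
    """Round n up to the nearest multiple (hypothetical helper)."""
    return ((n + multiple - 1) // multiple) * multiple


# Two hypothetical sequences of lengths 11 and 5.
seqs = [torch.arange(11), torch.arange(5)]

# First pad to the longest sequence in the batch: shape (2, 11).
padded = pad_sequence(seqs, batch_first=True)

# Then extend the last dimension on the right to the next multiple of 8 (16).
target = round_up(padded.shape[1])
padded8 = F.pad(padded, (0, target - padded.shape[1]))
print(padded8.shape)  # torch.Size([2, 16])
```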

pad2d

pad2d8

track_batch

Wrap an array or tensor to track the batch_index.
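A batch index typically records, for each element of a concatenated batch, which sample it came from. The sketch below builds such an index in plain PyTorch from made-up sample lengths. It illustrates the concept only, and the exact output format of `track_batch` is an assumption.

```python
import torch

# Hypothetical lengths of two samples concatenated along the first dimension.
lengths = [3, 2]

# Repeat each sample's position in the batch once per element it contributes.
batch_index = torch.repeat_interleave(
    torch.arange(len(lengths)), torch.tensor(lengths)
)
print(batch_index.tolist())  # [0, 0, 0, 1, 1]
```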

track_mask

Wrap an array or tensor to specify that its padding mask should be tracked.
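A padding mask marks which positions of a padded batch hold real data. The plain-PyTorch sketch below derives one from made-up sequence lengths, assuming the convention that `True` means a real element and `False` means padding. The convention used by `track_mask` is not stated here and may differ.

```python
import torch

# Hypothetical lengths of two sequences padded to a common length.
lengths = torch.tensor([3, 2])
max_len = int(lengths.max())

# Position j of sequence i is real data when j < lengths[i].
mask = torch.arange(max_len)[None, :] < lengths[:, None]
print(mask.tolist())  # [[True, True, True], [True, True, False]]
```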

track_mask8

Wrap an array or tensor to specify that its padding mask should be tracked.

track_mask2d

Wrap an array or tensor to specify that its padding mask should be tracked.

track_mask2d8

Wrap an array or tensor to specify that its padding mask should be tracked.

torch_brain.data.dataset [DEPRECATED]

Dataset

Deprecated.