# Source code for torch_brain.data.collate

import copy
from collections import namedtuple
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
from torch.utils.data._utils.collate import collate as _collate
from torch.utils.data._utils.collate import default_collate_fn_map

# pad
# Marker type produced by ``pad``; the top-level collate map dispatches on it.
PaddedObject = namedtuple("PaddedObject", "obj")


[docs] def pad(obj): r"""Wrap an object to specify that it (or any of its members) should be padded to the maximum length in the batch. The object can be any of the objects that PyTorch's :obj:`default_collate` already supports. Args: obj: Can be tensors, numpy arrays, lists, tuples, or dictionaries. """ return PaddedObject(obj)
[docs] def track_mask(input: Union[torch.Tensor, np.ndarray]): r"""Wrap an array or tensor to specify that its padding mask should be tracked. Args: input: An array or tensor. """ return pad(torch.ones((len(input)), dtype=torch.bool))
def pad_collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Zero-pad variable-length tensors and stack them batch-first.

    Every element is padded along its first dimension to the longest element,
    producing a tensor of shape ``(len(batch), max_len, *trailing_dims)``.

    Args:
        batch: Sequence of tensors whose dimensions after the first match.
        collate_fn_map: Unused; part of the handler signature expected by
            torch's ``collate``.
    """
    # Delegate to pad_sequence instead of a hand-written Python padding loop.
    padded = torch.nn.utils.rnn.pad_sequence(
        batch, padding_value=0, batch_first=True
    )
    return padded


# Handler map used inside a ``pad`` recipe: PyTorch's default handlers, with
# tensors padded instead of stacked.
pad_collate_fn_map = copy.deepcopy(default_collate_fn_map)
pad_collate_fn_map[torch.Tensor] = pad_collate_tensor_fn


def pad_collate_object_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Unwrap a batch of ``PaddedObject`` and collate the contents with padding.

    The incoming ``collate_fn_map`` is intentionally ignored: everything inside
    the wrapper is collated with ``pad_collate_fn_map``.
    """
    unwrapped = [item.obj for item in batch]
    return _collate(unwrapped, collate_fn_map=pad_collate_fn_map)


# pad8
Padded8Object = namedtuple("Padded8Object", ["obj"])
def pad8(obj):
    r"""Mark ``obj`` to be padded when batched, with lengths rounded up to 8.

    Works like :obj:`pad`, except that the padded length of the first
    dimension is rounded up to a multiple of 8.

    Args:
        obj: Can be tensors, numpy arrays, lists, tuples, or dictionaries.

    Returns:
        A ``Padded8Object`` wrapping ``obj``.
    """
    wrapped = Padded8Object(obj=obj)
    return wrapped
def track_mask8(input: Union[torch.Tensor, np.ndarray]):
    r"""Build a padding mask for ``input`` to be used together with :obj:`pad8`.

    The mask is a boolean tensor of ``True`` values, one per element of
    ``input``'s first dimension, wrapped with :obj:`pad8`. After collation,
    padded positions hold ``False``.

    Args:
        input: An array or tensor.
    """
    num_rows = len(input)
    mask = torch.ones(num_rows, dtype=torch.bool)
    return pad8(mask)
def pad8_collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Zero-pad variable-length tensors to a length that is a multiple of 8.

    Like the plain padding recipe (zero padding along the first dimension,
    batch-first output), except that the padded length is the longest first
    dimension in ``batch`` rounded up to a multiple of 8.

    Args:
        batch: Sequence (list or tuple) of tensors whose dimensions after the
            first match. It is not modified.
        collate_fn_map: Unused; part of the handler signature expected by
            torch's ``collate``.

    Returns:
        Tensor of shape ``(len(batch), L, *batch[0].shape[1:])``, where ``L``
        is the longest first dimension rounded up to a multiple of 8.
    """
    max_len = max(elem.shape[0] for elem in batch)
    if max_len % 8 == 0:
        return torch.nn.utils.rnn.pad_sequence(
            batch, batch_first=True, padding_value=0
        )
    elem = batch[0]
    target_len = max_len + 8 - (max_len % 8)
    # pad_sequence pads to the longest element, so add a zero sentinel of the
    # target length and drop its row afterwards. Build a new list rather than
    # appending to ``batch``: the caller's list must not be mutated, and torch's
    # collate hands this function a tuple when recursing into lists/tuples.
    # The sentinel must share the elements' device, or pad_sequence fails.
    sentinel = torch.zeros(
        (target_len, *elem.shape[1:]), dtype=elem.dtype, device=elem.device
    )
    padded = torch.nn.utils.rnn.pad_sequence(
        [*batch, sentinel], batch_first=True, padding_value=0
    )
    return padded[:-1]


# Handler map used inside a ``pad8`` recipe: PyTorch's default handlers, with
# tensors padded to a multiple of 8 instead of stacked.
pad8_collate_fn_map = copy.deepcopy(default_collate_fn_map)
pad8_collate_fn_map[torch.Tensor] = pad8_collate_tensor_fn


def pad8_collate_object_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Unwrap a batch of ``Padded8Object`` and collate the contents with ``pad8_collate_fn_map``.

    The incoming ``collate_fn_map`` is intentionally ignored.
    """
    return _collate([e.obj for e in batch], collate_fn_map=pad8_collate_fn_map)


# pad2d
Padded2dObject = namedtuple("Padded2dObject", ["obj"])
def pad2d(obj):
    r"""Mark ``obj`` so that its tensors are zero-padded along their first two dimensions.

    When batched, every tensor inside ``obj`` is padded to the largest first
    and the largest second dimension found in the batch. The tensors must have
    at least 2 dimensions.

    Args:
        obj: Can be tensors, numpy arrays, lists, tuples, or dictionaries.

    Returns:
        A ``Padded2dObject`` wrapping ``obj``.
    """
    wrapped = Padded2dObject(obj=obj)
    return wrapped
def track_mask2d(input: Union[torch.Tensor, np.ndarray]):
    r"""Build a 2-D padding mask for ``input`` to be used together with :obj:`pad2d`.

    Args:
        input: A 2-dimensional array or tensor.

    Returns:
        A boolean tensor of ``True`` with ``input``'s shape, wrapped with
        :obj:`pad2d`.

    Raises:
        ValueError: If ``input`` is not 2-dimensional.
    """
    num_dims = input.ndim
    if num_dims == 2:
        return pad2d(torch.ones(input.shape, dtype=torch.bool))
    raise ValueError(
        f"Expected input to have 2 dimensions, but got {num_dims} dimensions."
    )
def pad2d_collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Zero-pad tensors along their first two dimensions and stack them.

    Args:
        batch: Sequence of tensors with at least 2 dimensions; dimensions after
            the second are expected to match across the batch.
        collate_fn_map: Unused; part of the handler signature expected by
            torch's ``collate``.

    Returns:
        Tensor of shape ``(len(batch), max_n, max_m, *batch[0].shape[2:])``
        with the dtype and device of ``batch[0]``.

    Raises:
        ValueError: If any tensor has fewer than 2 dimensions.
    """
    if any(elem.ndim < 2 for elem in batch):
        raise ValueError("All tensors must have at least 2 dimensions.")
    max_n = max(elem.shape[0] for elem in batch)
    max_m = max(elem.shape[1] for elem in batch)
    first = batch[0]
    # new_zeros keeps the dtype AND device of the inputs; torch.zeros without
    # a device would allocate on the default device and move the result there.
    out = first.new_zeros((len(batch), max_n, max_m, *first.shape[2:]))
    for i, elem in enumerate(batch):
        out[i, : elem.shape[0], : elem.shape[1]] = elem
    return out


# Handler map used inside a ``pad2d`` recipe: PyTorch's default handlers, with
# tensors 2-D padded instead of stacked.
pad2d_collate_fn_map = copy.deepcopy(default_collate_fn_map)
pad2d_collate_fn_map[torch.Tensor] = pad2d_collate_tensor_fn


def pad2d_collate_object_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Unwrap a batch of ``Padded2dObject`` and collate the contents with ``pad2d_collate_fn_map``.

    The incoming ``collate_fn_map`` is intentionally ignored.
    """
    return _collate([e.obj for e in batch], collate_fn_map=pad2d_collate_fn_map)


# pad2d8
Padded2d8Object = namedtuple("Padded2d8Object", ["obj"])
def pad2d8(obj):
    r"""Mark ``obj`` for 2-D padding, with the second dimension rounded up to 8.

    Works like :obj:`pad2d`: the first dimension is padded to the batch
    maximum. The second dimension is padded to the batch maximum rounded up to
    a multiple of 8.

    Args:
        obj: Can be tensors, numpy arrays, lists, tuples, or dictionaries.

    Returns:
        A ``Padded2d8Object`` wrapping ``obj``.
    """
    wrapped = Padded2d8Object(obj=obj)
    return wrapped
def track_mask2d8(input: Union[torch.Tensor, np.ndarray]):
    r"""Build a 2-D padding mask for ``input`` to be used together with :obj:`pad2d8`.

    Args:
        input: A 2-dimensional array or tensor.

    Returns:
        A boolean tensor of ``True`` with ``input``'s shape, wrapped with
        :obj:`pad2d8`.

    Raises:
        ValueError: If ``input`` is not 2-dimensional.
    """
    num_dims = input.ndim
    if num_dims == 2:
        return pad2d8(torch.ones(input.shape, dtype=torch.bool))
    raise ValueError(
        f"Expected input to have 2 dimensions, but got {num_dims} dimensions."
    )
def pad2d8_collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Zero-pad tensors along their first two dims, rounding the second up to 8.

    The first dimension is padded to the batch maximum. The second dimension is
    padded to the batch maximum rounded up to a multiple of 8.

    Args:
        batch: Sequence of tensors with at least 2 dimensions; dimensions after
            the second are expected to match across the batch.
        collate_fn_map: Unused; part of the handler signature expected by
            torch's ``collate``.

    Returns:
        Tensor of shape ``(len(batch), max_n, max_m8, *batch[0].shape[2:])``
        with the dtype and device of ``batch[0]``.

    Raises:
        ValueError: If any tensor has fewer than 2 dimensions.
    """
    if any(elem.ndim < 2 for elem in batch):
        raise ValueError("All tensors must have at least 2 dimensions.")
    max_n = max(elem.shape[0] for elem in batch)
    max_m = max(elem.shape[1] for elem in batch)
    # round inner dim up to multiple of 8
    max_m = max_m + (8 - max_m % 8) % 8
    first = batch[0]
    # new_zeros keeps the dtype AND device of the inputs; torch.zeros without
    # a device would allocate on the default device and move the result there.
    out = first.new_zeros((len(batch), max_n, max_m, *first.shape[2:]))
    for i, elem in enumerate(batch):
        out[i, : elem.shape[0], : elem.shape[1]] = elem
    return out


# Handler map used inside a ``pad2d8`` recipe: PyTorch's default handlers, with
# tensors 2-D padded (inner dim rounded to 8) instead of stacked.
pad2d8_collate_fn_map = copy.deepcopy(default_collate_fn_map)
pad2d8_collate_fn_map[torch.Tensor] = pad2d8_collate_tensor_fn


def pad2d8_collate_object_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Unwrap a batch of ``Padded2d8Object`` and collate the contents with ``pad2d8_collate_fn_map``.

    The incoming ``collate_fn_map`` is intentionally ignored.
    """
    return _collate([e.obj for e in batch], collate_fn_map=pad2d8_collate_fn_map)


# chain
ChainObject = namedtuple("ChainObject", ["obj", "allow_missing_keys"])
ChainBatchTrackerObject = namedtuple("ChainBatchTrackerObject", ["obj"])
def chain(obj, allow_missing_keys: bool = False):
    r"""Mark ``obj`` (and its members) to be concatenated along the first dimension.

    Instead of stacking samples, every sequence in the batch is chained into one
    long sequence, similar to PyTorch Geometric's batching of graphs. Use
    :func:`track_batch` to record which sample each element came from.

    Args:
        obj: Can be tensors, numpy arrays, lists, tuples, or dictionaries.
        allow_missing_keys: If :obj:`True`, the wrapped dictionary may have
            different keys across the batch. Samples missing a key are
            skipped when that key is collated. Only valid for dictionaries.
            Defaults to :obj:`False`.

    Raises:
        TypeError: If ``allow_missing_keys`` is set and ``obj`` is not a dict.
    """
    if not allow_missing_keys or isinstance(obj, dict):
        return ChainObject(obj, allow_missing_keys)
    raise TypeError(
        f"allow_missing_keys can only be used with dictionaries, got {type(obj)}."
    )
def track_batch(input: Union[torch.Tensor, np.ndarray]):
    r"""Track, per element of ``input``, the index of the sample it belongs to.

    Produces a ``long`` tensor of ones with one entry per element of
    ``input``'s first dimension. At collation time each sample's ones are
    multiplied by that sample's position in the batch and concatenated.

    Args:
        input: An array or tensor.
    """
    ones = torch.ones(len(input), dtype=torch.long)
    return ChainBatchTrackerObject(ones)
def chain_collate_str_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Return a batch of strings as-is; no collation is applied."""
    # No-op.
    return batch


def chain_collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Concatenate a batch of tensors along their first dimension."""
    return torch.cat(batch, dim=0)


def _collate_batch_trackers(batch, sample_indices):
    """Multiply each tracker's ones by its sample index and chain the results."""
    return _collate(
        [index * elem.obj for index, elem in zip(sample_indices, batch)],
        collate_fn_map=chain_collate_fn_map,
    )


def chain_batch_tracker_collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Collate batch trackers into one chained tensor of per-element sample indices.

    The ``i``-th tracker in ``batch`` is multiplied by ``i`` before chaining.
    """
    return _collate_batch_trackers(batch, range(len(batch)))


# Handler map used inside a ``chain`` recipe: strings are left as-is, tensors
# are concatenated, and batch trackers become sample indices.
chain_collate_fn_map = copy.deepcopy(default_collate_fn_map)
chain_collate_fn_map[str] = chain_collate_str_fn
chain_collate_fn_map[torch.Tensor] = chain_collate_tensor_fn
chain_collate_fn_map[ChainBatchTrackerObject] = chain_batch_tracker_collate_tensor_fn


def _chain_collate_fn_map_for(sample_indices):
    """Copy ``chain_collate_fn_map`` with batch trackers bound to ``sample_indices``.

    Used when samples lacking a key are skipped. Trackers must still receive
    each remaining sample's position in the full batch, not its position in
    the filtered list.
    """

    def tracker_fn(batch, *, collate_fn_map=None):
        return _collate_batch_trackers(batch, sample_indices)

    fn_map = dict(chain_collate_fn_map)
    fn_map[ChainBatchTrackerObject] = tracker_fn
    return fn_map


def chain_collate_object_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Unwrap a batch of ``ChainObject`` and collate the contents by chaining.

    If the wrappers allow missing keys, each key is collated only from the
    samples that contain it. Batch trackers under that key still report the
    samples' positions in the full batch.

    Raises:
        ValueError: If ``allow_missing_keys`` differs across the batch.
    """
    allow_missing_keys = batch[0].allow_missing_keys
    # check that flag is consistent
    if any((elem.allow_missing_keys != allow_missing_keys) for elem in batch):
        raise ValueError(
            "attribute 'allow_missing_keys' must be the same for all elements in the "
            "batch. Some elements allow missing keys while others do not."
        )
    if not allow_missing_keys:
        return _collate(
            [elem.obj for elem in batch], collate_fn_map=chain_collate_fn_map
        )

    # dict.fromkeys keeps first-seen key order. A set's iteration order for str
    # keys depends on hash randomization and would change between runs.
    unique_keys = dict.fromkeys(key for elem in batch for key in elem.obj)
    collated = {}
    for key in unique_keys:
        present = [i for i, elem in enumerate(batch) if key in elem.obj]
        collated[key] = _collate(
            [batch[i].obj[key] for i in present],
            collate_fn_map=_chain_collate_fn_map_for(present),
        )
    return collated


# Add all wrapper types to the top-level collate map.
# Once a recipe is selected for an object, wrappers nested inside it cannot
# override it: the per-recipe maps (pad_collate_fn_map, chain_collate_fn_map,
# ...) do not register the wrapper types. This avoids ambiguity in scenarios
# such as pad(obj) being called on a dict one of whose values was already
# wrapped in chain.
# Top-level handler map used by ``collate``: PyTorch's default handlers plus one
# entry per wrapper type (registered in this order).
collate_fn_map = copy.deepcopy(default_collate_fn_map)
collate_fn_map.update(
    {
        PaddedObject: pad_collate_object_fn,
        Padded8Object: pad8_collate_object_fn,
        Padded2dObject: pad2d_collate_object_fn,
        Padded2d8Object: pad2d8_collate_object_fn,
        ChainObject: chain_collate_object_fn,
        ChainBatchTrackerObject: chain_batch_tracker_collate_tensor_fn,
    }
)
def collate(batch):
    r"""Collate a batch like :obj:`default_collate`, with support for variable-length samples.

    The collation recipe is chosen by wrapping objects in a sample:

    - :obj:`pad`, :obj:`pad8`, :obj:`pad2d`, :obj:`pad2d8` zero-pad tensors to
      the largest size in the batch. The ``8`` variants round up to multiples
      of 8.
    - :obj:`chain` concatenates sequences along the first dimension.

    A wrapped Iterable or Mapping passes its recipe on to all of its members.
    Anything :obj:`default_collate` supports can be wrapped, and unwrapped
    objects are collated exactly as :obj:`default_collate` would.

    The recipes do not record padding masks or batch indices by themselves. Use
    :obj:`track_mask` (or its ``8``/``2d``/``2d8`` variants) and
    :obj:`track_batch` on the arrays or tensors that need them.

    Args:
        batch: a single batch to be collated
    """
    return _collate(batch, collate_fn_map=collate_fn_map)
# Public API: the collate entry point, the recipe wrappers, and the trackers.
__all__ = [
    "collate",
    "chain",
    "pad",
    "pad8",
    "pad2d",
    "pad2d8",
    "track_batch",
    "track_mask",
    "track_mask8",
    "track_mask2d",
    "track_mask2d8",
]