Embedding

class torch_brain.nn.Embedding(num_embeddings, embedding_dim, init_scale=0.02, **kwargs)

Bases: torch.nn.modules.sparse.Embedding

A simple extension of torch.nn.Embedding that allows more control over weight initialization. The learnable weights of the module, of shape (num_embeddings, embedding_dim), are initialized from \(\mathcal{N}(0, \text{init\_scale}^2)\), i.e. a zero-mean normal distribution with standard deviation init_scale.
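The initialization described above can be sketched as a subclass that overrides reset_parameters. This is a minimal illustration under assumptions, not the torch_brain source: the class name NormalInitEmbedding is hypothetical, and the sketch does not special-case padding_idx (the base class zeroes that row in its own reset_parameters, which this override replaces).

```python
import torch
import torch.nn as nn


class NormalInitEmbedding(nn.Embedding):
    """Hypothetical sketch of an embedding with a configurable normal init."""

    def __init__(self, num_embeddings, embedding_dim, init_scale=0.02, **kwargs):
        # nn.Embedding.__init__ calls self.reset_parameters() when it creates
        # the weight, so init_scale must exist before super().__init__ runs.
        self.init_scale = init_scale
        super().__init__(num_embeddings, embedding_dim, **kwargs)

    def reset_parameters(self):
        # Draw every weight from N(0, init_scale^2); std is init_scale.
        # Note: unlike the base class, this does not zero a padding_idx row.
        nn.init.normal_(self.weight, mean=0.0, std=self.init_scale)
```

For example, NormalInitEmbedding(1000, 64) should produce weights whose sample standard deviation is close to 0.02, whereas a plain torch.nn.Embedding of the same size starts near 1.0.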

Parameters:
  • num_embeddings (int) – size of the dictionary of embeddings

  • embedding_dim (int) – the size of each embedding vector

  • init_scale (float, optional) – standard deviation of the normal distribution used to initialize the weights. Defaults to 0.02, a common choice for transformer models (e.g., BERT and GPT-2)

  • **kwargs – Additional keyword arguments passed on to torch.nn.Embedding (e.g., padding_idx). Refer to the documentation of torch.nn.Embedding for details
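Because the class subclasses torch.nn.Embedding, the forward pass is the inherited index lookup. The stand-in below uses plain PyTorch only (it does not import torch_brain) and assumes the documented initializer amounts to redrawing the weights with torch.nn.init.normal_ at std=init_scale; the sizes and indices are arbitrary.

```python
import torch
import torch.nn as nn

num_embeddings, embedding_dim, init_scale = 10, 4, 0.02

# Stand-in for Embedding(num_embeddings, embedding_dim, init_scale):
# a plain nn.Embedding whose weights are redrawn from N(0, init_scale^2).
emb = nn.Embedding(num_embeddings, embedding_dim)
nn.init.normal_(emb.weight, mean=0.0, std=init_scale)

# Integer indices of shape (batch, seq_len) map to vectors of shape
# (batch, seq_len, embedding_dim); each index selects one weight row.
tokens = torch.tensor([[1, 2, 3], [4, 5, 9]])
out = emb(tokens)
print(out.shape)  # torch.Size([2, 3, 4])
```

init_scale only affects the initial values of the weight table; indexing and output shapes follow torch.nn.Embedding unchanged.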

reset_parameters()

Resets all learnable parameters of the module.
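For context, torch.nn.Embedding's own reset_parameters draws the weights from the standard normal N(0, 1), and its constructor calls that method, so overriding it is the natural hook for changing the initial scale. The snippet below uses only plain PyTorch (it does not call torch_brain) to show the base-class default next to a manual redraw at the 0.02 default documented above; the table size is arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Plain torch.nn.Embedding: construction calls reset_parameters(),
# which draws every weight from the standard normal N(0, 1).
base = nn.Embedding(1000, 64)

# Clobber the weights, then call the hook again to get a fresh draw.
with torch.no_grad():
    base.weight.zero_()
base.reset_parameters()
base_std = base.weight.std().item()  # about 1.0

# Redrawing with std=0.02 instead yields weights roughly 50x smaller in scale.
nn.init.normal_(base.weight, mean=0.0, std=0.02)
small_std = base.weight.std().item()  # about 0.02
```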