Embedding
class torch_brain.nn.Embedding(num_embeddings, embedding_dim, init_scale=0.02, **kwargs)
Bases: torch.nn.modules.sparse.Embedding

A simple extension of torch.nn.Embedding that allows more control over the weight initializer. The learnable weights of the module, of shape (num_embeddings, embedding_dim), are initialized from \(\mathcal{N}(0, \text{init\_scale}^2)\), i.e. a zero-mean normal distribution with standard deviation init_scale.

Parameters:
num_embeddings (int) – size of the dictionary of embeddings
embedding_dim (int) – the size of each embedding vector
init_scale (float, optional) – standard deviation of the normal distribution used for the initialization. Defaults to 0.02, a common choice in transformer models (for example, GPT-2 and BERT initialize weights with a standard deviation of 0.02)
**kwargs – additional arguments. Refer to the documentation of torch.nn.Embedding for details
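A minimal sketch of the initialization behavior described above, assuming the class simply re-initializes the weight allocated by torch.nn.Embedding using torch.nn.init.normal_ with std=init_scale. `NormalInitEmbedding` is a hypothetical name; this is not the torch_brain source. Zeroing the padding_idx row is also an assumption, mirroring torch.nn.Embedding's own convention.

```python
import torch
import torch.nn as nn


class NormalInitEmbedding(nn.Embedding):
    """Hypothetical stand-in for torch_brain.nn.Embedding (not its actual source).

    Weights are drawn from N(0, init_scale**2), i.e. init_scale is the std.
    """

    def __init__(self, num_embeddings, embedding_dim, init_scale=0.02, **kwargs):
        # Let torch.nn.Embedding allocate the weight and process kwargs
        # such as padding_idx, max_norm or sparse.
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.init_scale = init_scale
        # Re-initialize the weight with the requested standard deviation.
        nn.init.normal_(self.weight, mean=0.0, std=init_scale)
        # Assumption: keep nn.Embedding's convention that the padding row is zero.
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)


torch.manual_seed(0)
emb = NormalInitEmbedding(num_embeddings=1000, embedding_dim=64, init_scale=0.02)
print(emb.weight.shape)  # torch.Size([1000, 64])
print(emb.weight.std().item())  # approximately 0.02

# Look up a batch of token indices.
tokens = torch.tensor([[1, 5, 42]])
print(emb(tokens).shape)  # torch.Size([1, 3, 64])

# Extra keyword arguments are forwarded to torch.nn.Embedding.
pad_emb = NormalInitEmbedding(10, 4, init_scale=1.0, padding_idx=0)
print(pad_emb.weight[0])  # all zeros
```

Re-initializing after super().__init__ (rather than overriding reset_parameters) avoids reading self.init_scale before it has been set, since torch.nn.Embedding calls reset_parameters from its own constructor.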