Source code for torch_brain.nn.embedding
import torch
import torch.nn as nn
[docs]
class Embedding(nn.Embedding):
r"""A simple extension of :class:`torch.nn.Embedding` to allow more control over
the weights initializer. The learnable weights of the module of shape
`(num_embeddings, embedding_dim)` are initialized from
:math:`\mathcal{N}(0, \text{init_scale})`.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
init_scale (float, optional): standard deviation of the normal distribution used
for the initialization. Defaults to 0.02, which is the default value used in
most transformer models
**kwargs: Additional arguments. Refer to the documentation of
:class:`torch.nn.Embedding` for details
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
init_scale: float = 0.02,
**kwargs,
):
self.init_scale = init_scale
super().__init__(num_embeddings, embedding_dim, **kwargs)
[docs]
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
torch.nn.init.normal_(self.weight, mean=0, std=self.init_scale)
self._fill_padding_idx_with_zero()