RotaryTimeEmbedding
- class torch_brain.nn.RotaryTimeEmbedding(head_dim, rotate_dim, t_min, t_max)
Bases: torch.nn.modules.module.Module
Rotary time/positional embedding layer. This module is designed to be used with torch_brain.nn.RotarySelfAttention and torch_brain.nn.RotaryCrossAttention to modulate the attention weights in accordance with the relative timing/positions of the tokens. Original paper: RoFormer: Enhanced Transformer with Rotary Position Embedding.
The time periods are computed using generate_logspace_timeperiods().
- Parameters:
head_dim (int) – Dimension of the attention head.
rotate_dim (int) – Number of dimensions to rotate. Use this parameter to rotate only a portion of the head dimension. For example, PerceiverIO found that rotating only half of the dimensions is effective.
t_min (float) – Minimum period of the sinusoids. Set this to the smallest timescale the attention layer should care about.
t_max (float) – Maximum period of the sinusoids. Set this to the largest timescale the attention layer should care about.
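As a rough illustration of how t_min and t_max bound the sinusoid periods, the sketch below builds geometrically spaced periods in plain Python. It does not call generate_logspace_timeperiods(); the helper names, the choice of rotate_dim // 2 periods (one per rotated pair), and the angle formula 2πt/T are assumptions made for illustration and may differ from the library's internals.

```python
import math

def logspace_periods(num, t_min, t_max):
    """Return `num` periods spaced geometrically from t_min to t_max (inclusive)."""
    if num == 1:
        return [t_min]
    ratio = (t_max / t_min) ** (1.0 / (num - 1))
    return [t_min * ratio**i for i in range(num)]

def rotation_angles(t, periods):
    """Angle in radians that each sinusoid has swept at time t (assumed 2*pi*t/T)."""
    return [2.0 * math.pi * t / p for p in periods]

# Hypothetical settings: rotate 8 head dimensions, i.e. 4 (cos, sin) pairs.
rotate_dim = 8
periods = logspace_periods(rotate_dim // 2, t_min=1e-3, t_max=1.0)
angles = rotation_angles(0.25, periods)

print(periods)  # shortest period is t_min, longest is t_max
print(angles)   # short periods sweep large angles, long periods small ones
```

Under these assumptions, the shortest period resolves fine timing differences while the longest one keeps tokens that are far apart distinguishable, which is why the two bounds should match the timescales the attention layer needs to resolve.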
- forward(timestamps)
Computes the rotary embeddings for the given timestamps, which can then be used by RotaryTimeEmbedding.rotate().
- Parameters:
timestamps (torch.Tensor) – timestamps tensor.
- Return type:
- static rotate(x, rotary_emb, unsqueeze_dim=2)
Apply the rotary positional embedding to the input data.
- Parameters:
x (torch.Tensor) – Input data.
rotary_emb (torch.Tensor) – The rotary embedding produced by a forward call of RotaryTimeEmbedding.
unsqueeze_dim (int, optional) – Dimension where heads are located in the input tensor. For example, for input shape (batch, heads, seq_len, dim) use 1; for input shape (batch, seq_len, heads, dim) use 2. Defaults to 2.
- Return type:
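The property that makes rotation useful for attention is that the dot product of two rotated vectors depends only on the difference between their timestamps. The plain-Python sketch below checks this numerically. It is not the library's implementation: the interleaved pairing of dimensions, the helper names, and the example periods are assumptions, and the real rotate() operates on tensors with a heads dimension.

```python
import math

def angles_at(t, periods):
    """Assumed angle convention: 2*pi*t/T for each period T."""
    return [2.0 * math.pi * t / p for p in periods]

def rotate_pairs(x, angles):
    """Rotate consecutive (even, odd) pairs of x by the matching angle."""
    out = []
    for i, theta in enumerate(angles):
        a, b = x[2 * i], x[2 * i + 1]
        c, s = math.cos(theta), math.sin(theta)
        out += [a * c - b * s, a * s + b * c]
    # Dimensions beyond 2 * len(angles) pass through unrotated.
    return out + list(x[2 * len(angles):])

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

periods = [0.1, 1.0]            # hypothetical periods, one per rotated pair
q = [1.0, 2.0, 3.0, 4.0, 5.0]   # the last entry is left unrotated
k = [0.5, -1.0, 2.0, 0.0, 1.0]

# Same time difference (0.3) at two different absolute times.
score_a = dot(rotate_pairs(q, angles_at(0.5, periods)),
              rotate_pairs(k, angles_at(0.2, periods)))
score_b = dot(rotate_pairs(q, angles_at(7.5, periods)),
              rotate_pairs(k, angles_at(7.2, periods)))

print(abs(score_a - score_b) < 1e-9)
```

The two scores agree up to floating-point error, so shifting both timestamps by the same offset leaves the attention logit unchanged. The unrotated trailing dimension contributes the same term in both cases, which is how a rotate_dim smaller than head_dim leaves part of the head position-independent.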
- static invert(rotary_emb)
Invert/Negate rotary embedding. If the input embeddings correspond to a time \(t\), then the output embeddings correspond to time \(-t\).
- Parameters:
rotary_emb (torch.Tensor) – Embeddings produced by a forward call of RotaryTimeEmbedding.
- Return type:
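To see why an embedding for time \(-t\) undoes a rotation by \(t\), the sketch below represents each embedding as a list of (cos, sin) pairs in plain Python. This representation, the helper names, and the example periods are assumptions for illustration; the library's actual tensor layout may differ.

```python
import math

def embedding(t, periods):
    """(cos, sin) per period at time t; a simplified stand-in for a rotary embedding."""
    angles = [2.0 * math.pi * t / p for p in periods]
    return [(math.cos(a), math.sin(a)) for a in angles]

def invert(emb):
    """Negating the angle keeps the cosine and flips the sign of the sine."""
    return [(c, -s) for c, s in emb]

def rotate_pairs(x, emb):
    """Rotate consecutive (even, odd) pairs of x using the (cos, sin) pairs."""
    out = []
    for i, (c, s) in enumerate(emb):
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out

periods = [0.1, 1.0]   # hypothetical periods
emb = embedding(0.37, periods)
inv = invert(emb)

x = [1.0, 2.0, 3.0, 4.0]
roundtrip = rotate_pairs(rotate_pairs(x, emb), inv)

print(roundtrip)  # matches x up to floating-point error
```

In this simplified form, the inverted embedding equals the one computed directly at time -0.37, and applying it after the forward rotation recovers the original vector.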