RotaryTimeEmbedding

class torch_brain.nn.RotaryTimeEmbedding(head_dim, rotate_dim, t_min, t_max)

Bases: torch.nn.modules.module.Module

Rotary time/positional embedding layer. It is designed to be used with torch_brain.nn.RotarySelfAttention and torch_brain.nn.RotaryCrossAttention, where it modulates the attention weights according to the relative timing (or positions) of the tokens. Original paper: RoFormer: Enhanced Transformer with Rotary Position Embedding.

The time periods are computed with generate_logspace_timeperiods().
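The following is a minimal sketch of log-spaced periods. It assumes that generate_logspace_timeperiods() spaces rotate_dim // 2 periods geometrically between t_min and t_max, with one period per rotated pair of dimensions. The helper logspace_periods is hypothetical and written in plain Python rather than PyTorch. It is not the library's implementation.

```python
import math


def logspace_periods(t_min: float, t_max: float, num: int) -> list[float]:
    """Hypothetical helper: `num` periods evenly spaced in log space from t_min to t_max."""
    if num == 1:
        return [t_min]
    log_min, log_max = math.log(t_min), math.log(t_max)
    step = (log_max - log_min) / (num - 1)
    return [math.exp(log_min + i * step) for i in range(num)]


# Assumed setup: rotate_dim = 8 -> 4 rotated pairs -> 4 periods.
periods = logspace_periods(t_min=1e-3, t_max=1.0, num=4)
print([round(p, 6) for p in periods])  # [0.001, 0.01, 0.1, 1.0]
```

Geometric spacing assigns the same number of sinusoids to each decade of timescales, so short and long timescales are covered evenly.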

Parameters:
  • head_dim (int) – Dimension of the attention head.

  • rotate_dim (int) – Number of dimensions to rotate. This parameter lets you rotate only a portion of the head dimension. For example, PerceiverIO found rotating only half of the dimensions to be effective.

  • t_min (float) – Minimum period of the sinusoids. Set this to the smallest timescale the attention layer should care about.

  • t_max (float) – Maximum period of the sinusoids. Set this to the largest timescale the attention layer should care about.
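To see why t_min and t_max bound the timescales the layer can resolve, consider the phase that each sinusoid advances over a time lag \(\Delta t\). The sketch below assumes the angle for period \(P\) is \(2\pi \Delta t / P\), which is the standard rotary formulation. The exact scaling inside torch_brain is an assumption here, and the phase helper is hypothetical.

```python
import math


def phase(dt: float, period: float) -> float:
    """Hypothetical helper: angle (radians) a sinusoid with `period` advances over lag `dt`."""
    return 2 * math.pi * dt / period


dt = 0.05  # an illustrative 50 ms lag
for period in (1e-3, 1e-2, 1e-1, 1.0):
    turns = phase(dt, period) / (2 * math.pi)
    print(f"period={period:g}: {turns:g} full turns")
```

At this lag, the 1e-3 period completes 50 full turns, so its phase cannot distinguish the lag from many others. The 1.0 period moves only 0.05 of a turn. Periods between these extremes carry the useful timing information.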

forward(timestamps)

Compute the rotary embeddings for the given timestamps. The result can then be passed to RotaryTimeEmbedding.rotate().

Parameters:

timestamps (torch.Tensor) – Timestamps of the tokens to embed.

Return type:

Tensor
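The following is a rough sketch of what forward() computes. It assumes each timestamp \(t\) is mapped to one angle \(2\pi t / P_i\) per period \(P_i\), and that the embedding stores the cosine and sine of those angles. The plain-Python helpers rotary_angles and cos_sin are hypothetical. The actual tensor layout returned by forward() is not specified here, for example whether angles are repeated across each rotated pair.

```python
import math


def rotary_angles(timestamps: list[float], periods: list[float]) -> list[list[float]]:
    """Hypothetical helper: one row of angles (radians) per timestamp, one column per period."""
    return [[2 * math.pi * t / p for p in periods] for t in timestamps]


def cos_sin(angles: list[list[float]]) -> tuple[list[list[float]], list[list[float]]]:
    """Cosine and sine tables of the angles, the quantities a rotation needs."""
    cos = [[math.cos(a) for a in row] for row in angles]
    sin = [[math.sin(a) for a in row] for row in angles]
    return cos, sin


periods = [0.01, 0.1, 1.0]     # assumed, e.g. from a log-spaced schedule
timestamps = [0.0, 0.25, 0.5]  # token times in seconds
angles = rotary_angles(timestamps, periods)
cos, sin = cos_sin(angles)
print(cos[0])  # t = 0 gives zero angles, so every cosine is 1.0
```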

static rotate(x, rotary_emb, unsqueeze_dim=2)

Apply the rotary positional embedding to the input data.

Parameters:
  • x (torch.Tensor) – Input tensor to rotate, e.g. attention queries or keys.

  • rotary_emb (torch.Tensor) – The rotary embedding produced by a forward call of RotaryTimeEmbedding.

  • unsqueeze_dim (int, optional) – Index of the heads dimension in the input tensor. For example, use 1 for input of shape (batch, heads, seq_len, dim) and 2 for input of shape (batch, seq_len, heads, dim). Defaults to 2.

Return type:

Tensor
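The rotation itself can be sketched as follows. This sketch assumes the RoFormer convention of rotating adjacent pairs \((x_{2i}, x_{2i+1})\) by angle \(\theta_i\) and leaving dimensions beyond rotate_dim untouched. Which dimensions torch_brain actually pairs and rotates is an assumption here. The helper rotate_vec is hypothetical plain Python, not RotaryTimeEmbedding.rotate(). The final check shows the property that makes rotary embeddings useful for attention: the dot product of a rotated query and a rotated key depends only on the time difference between them.

```python
import math


def rotate_vec(x: list[float], t: float, periods: list[float]) -> list[float]:
    """Rotate adjacent pairs of x by angles 2*pi*t/period; remaining dims pass through."""
    rotate_dim = 2 * len(periods)
    if len(x) < rotate_dim:
        raise ValueError("x has fewer dimensions than rotate_dim")
    out = list(x)
    for i, p in enumerate(periods):
        theta = 2 * math.pi * t / p
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out  # dims from rotate_dim onward are unchanged


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


periods = [0.1, 1.0]                  # rotate_dim = 4
q = [0.3, -1.2, 0.7, 0.5, 2.0, -0.4]  # head_dim = 6, last 2 dims not rotated
k = [1.1, 0.4, -0.6, 0.9, 0.2, 1.5]

# Same lag (0.2 s) at two different absolute times -> same attention score.
s1 = dot(rotate_vec(q, 1.2, periods), rotate_vec(k, 1.0, periods))
s2 = dot(rotate_vec(q, 5.2, periods), rotate_vec(k, 5.0, periods))
print(abs(s1 - s2) < 1e-9)  # True
```

The tensor version presumably uses unsqueeze_dim to insert a singleton heads axis into rotary_emb, so that the same rotation broadcasts across all heads. The list-based sketch has no heads axis.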

static invert(rotary_emb)

Invert (negate) a rotary embedding. If the input embeddings correspond to time \(t\), the output embeddings correspond to time \(-t\).

Parameters:

rotary_emb (torch.Tensor) – Embeddings produced by a forward call of RotaryTimeEmbedding.

Return type:

Tensor
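Negating the time inverts the embedding because rotating by \(\theta\) and then by \(-\theta\) is the identity. An embedding for time \(-t\) therefore undoes one for time \(t\). The sketch below checks this numerically with a hypothetical pairwise rotation (rotate_pairs), which is a plain-Python stand-in for rotate(). In a (cos, sin) representation, negating the time leaves the cosines unchanged and flips the sign of the sines, since \(\cos(-\theta) = \cos\theta\) and \(\sin(-\theta) = -\sin\theta\). Whether invert() is implemented exactly this way is an assumption.

```python
import math


def rotate_pairs(x: list[float], angles: list[float]) -> list[float]:
    """Rotate adjacent pairs (x[2i], x[2i+1]) by angles[i]; extra dims pass through."""
    out = list(x)
    for i, theta in enumerate(angles):
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def angles_at(t: float, periods: list[float]) -> list[float]:
    """Hypothetical helper: rotation angles for time t, one per period."""
    return [2 * math.pi * t / p for p in periods]


periods = [0.1, 1.0]
x = [0.3, -1.2, 0.7, 0.5]
t = 2.73

rotated = rotate_pairs(x, angles_at(t, periods))           # embedding for time t
restored = rotate_pairs(rotated, angles_at(-t, periods))   # inverted embedding, time -t
print(all(abs(a - b) < 1e-9 for a, b in zip(restored, x)))  # True
```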