RotarySelfAttention

class torch_brain.nn.RotarySelfAttention(*, dim, heads=8, dim_head=64, dropout=0.0, rotate_value=False, use_xformers=True)

Bases: torch.nn.modules.module.Module

Self-attention layer with rotary positional embeddings.

This layer performs self-attention within a sequence, applying rotary positional embeddings to the queries and keys (and optionally the values). It first normalizes the input, projects it into query, key, and value spaces, applies the rotary embeddings, computes multi-head attention, and finally projects the result back to the input dimension.
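To make the rotary step concrete, here is a minimal PyTorch sketch of one common rotary convention: the "rotate-half" channel layout with base-10000 frequencies, driven by continuous timestamps. The helper names rotary_angles, rotate_half and apply_rotary, the base value and the channel layout are assumptions for illustration; torch_brain may construct rotary_time_emb differently. The check at the end shows the property that motivates rotary embeddings: attention scores depend only on time differences, so shifting every timestamp by the same offset leaves them unchanged.

```python
import torch


def rotary_angles(timestamps: torch.Tensor, dim_head: int, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: map timestamps (..., N) to rotation angles (..., N, dim_head)."""
    # One frequency per pair of channels, as in standard RoPE (base is an assumed choice).
    inv_freq = 1.0 / (base ** (torch.arange(0, dim_head, 2, dtype=torch.float32) / dim_head))
    angles = timestamps[..., None].float() * inv_freq  # (..., N, dim_head // 2)
    # Duplicate so channel i and channel i + dim_head // 2 share one angle ("rotate-half" layout).
    return torch.cat([angles, angles], dim=-1)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Maps each channel pair (x1, x2) to (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Rotate each channel pair of x (..., N, dim_head) by its matching angle."""
    return x * angles.cos() + rotate_half(x) * angles.sin()


torch.manual_seed(0)
B, N, D_h = 2, 5, 16
q = torch.randn(B, N, D_h)
k = torch.randn(B, N, D_h)
t = torch.rand(B, N) * 10.0  # continuous timestamps, e.g. in seconds

emb = rotary_angles(t, D_h)  # stands in for rotary_time_emb: (B, N, D_h)
scores = apply_rotary(q, emb) @ apply_rotary(k, emb).transpose(-1, -2)

# Shift every timestamp by the same offset: the pairwise scores do not change.
emb_shifted = rotary_angles(t + 3.7, D_h)
scores_shifted = apply_rotary(q, emb_shifted) @ apply_rotary(k, emb_shifted).transpose(-1, -2)
print(torch.allclose(scores, scores_shifted, atol=1e-3))  # True
```

The invariance follows because rotating q by angle a and k by angle b gives a dot product that depends only on b - a.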

The layer provides two forward methods:

  • forward(): The default method, for batches whose sequences all have the same length, either natively or after padding. When sequences are padded, an attention mask must be provided so that padded positions are ignored.

  • forward_varlen(): For sequences concatenated along a single token dimension instead of stacked into a padded batch. It takes per-sequence lengths rather than a mask and, because no padding is stored, can be more memory efficient.
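The two input layouts can be related with a small conversion, sketched below under stated assumptions. The helpers pad_to_varlen and key_padding_mask are hypothetical, and the mask convention (True means the query may attend to that key, as in torch.nn.functional.scaled_dot_product_attention) is assumed rather than taken from torch_brain. The padded batch and the (B, N, N) mask would go to forward(); the concatenated tokens and the lengths would go to forward_varlen().

```python
import torch


def pad_to_varlen(x_padded: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B, N_max, D) padded batch -> (N_total, D) concatenated tokens."""
    N_max = x_padded.shape[1]
    # valid[b, n] is True for real tokens and False for padding.
    valid = torch.arange(N_max)[None, :] < lengths[:, None]  # (B, N_max)
    # Boolean indexing keeps row-major order: all of sequence 0, then sequence 1, and so on.
    return x_padded[valid]


def key_padding_mask(lengths: torch.Tensor, N_max: int) -> torch.Tensor:
    """Hypothetical helper: (B, N_max, N_max) mask, True where query i may attend to key j (assumed convention)."""
    valid = torch.arange(N_max)[None, :] < lengths[:, None]  # (B, N_max)
    # Mask keys only, so every query row (padded or not) keeps at least one real key,
    # provided each sequence has at least one token.
    return valid[:, None, :].expand(-1, N_max, -1)


torch.manual_seed(0)
lengths = torch.tensor([3, 1, 4])            # plays the role of x_seqlen
x_padded = torch.randn(3, 4, 8)              # (B, N_max, D) layout for forward()
x_varlen = pad_to_varlen(x_padded, lengths)  # (N_total, D) layout for forward_varlen()
mask = key_padding_mask(lengths, 4)          # (B, N, N) x_mask for forward()
print(x_varlen.shape, mask.shape)  # torch.Size([8, 8]) torch.Size([3, 4, 4])
```

Masking only keys, rather than both queries and keys, avoids fully masked rows, which can make the softmax produce NaN for padded queries.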

Parameters:
  • dim (int) – Dimension of the input embeddings.

  • heads (int) – Number of attention heads. Defaults to 8.

  • dim_head (int) – Dimension of each attention head. Defaults to 64.

  • dropout (float) – Dropout probability. Defaults to 0.0.

  • rotate_value (bool) – Whether to apply rotary embeddings to the values as well as the queries and keys. Defaults to False.

  • use_xformers (bool) – Whether to use xformers for attention. Defaults to True.

forward(x, rotary_time_emb, x_mask=None)

Forward pass for fixed-length sequences.

Shape:
  • x: (B, N, D)

  • rotary_time_emb: (B, N, D_h)

  • x_mask: (B, N, N)

  • Output: (B, N, D)

where B is the batch size, N is the sequence length, D is the input dimension (dim), and D_h is the head dimension (dim_head).
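The shape contract above can be exercised with a minimal, self-contained re-implementation of the pipeline this page describes: normalize, project to queries/keys/values, rotate queries and keys, attend, project back. It is a sketch assuming PyTorch 2.x, not torch_brain's code: the class name RotarySelfAttentionSketch is hypothetical, LayerNorm as the normalization and the rotate-half rotary layout are assumptions, values are left unrotated (the rotate_value=False case), and x_mask is assumed to be boolean with True meaning "may attend". PyTorch's built-in scaled_dot_product_attention stands in for the xformers kernel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary(x, emb):
    # x: (B, H, N, D_h); emb: (B, N, D_h) holds rotation angles, broadcast over heads.
    emb = emb.unsqueeze(1)
    return x * emb.cos() + rotate_half(x) * emb.sin()


class RotarySelfAttentionSketch(nn.Module):
    """Hypothetical sketch of the documented forward() contract (not torch_brain's code)."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        self.heads, self.dropout = heads, dropout
        inner = heads * dim_head
        self.norm = nn.LayerNorm(dim)  # assumed normalization
        self.to_qkv = nn.Linear(dim, 3 * inner, bias=False)
        self.to_out = nn.Linear(inner, dim)

    def forward(self, x, rotary_time_emb, x_mask=None):
        B, N, _ = x.shape
        q, k, v = self.to_qkv(self.norm(x)).chunk(3, dim=-1)
        # (B, N, H * D_h) -> (B, H, N, D_h)
        q, k, v = (t.reshape(B, N, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        q, k = apply_rotary(q, rotary_time_emb), apply_rotary(k, rotary_time_emb)
        mask = None if x_mask is None else x_mask.unsqueeze(1)  # (B, 1, N, N), True = attend
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0
        )
        out = out.transpose(1, 2).reshape(B, N, -1)  # (B, N, H * D_h)
        return self.to_out(out)  # (B, N, D)


torch.manual_seed(0)
B, N, D, H, D_h = 2, 6, 32, 4, 16
layer = RotarySelfAttentionSketch(D, heads=H, dim_head=D_h)
x = torch.randn(B, N, D)
rotary_time_emb = torch.rand(B, N, D_h) * 3.0  # stand-in rotation angles
x_mask = torch.ones(B, N, N, dtype=torch.bool)
y = layer(x, rotary_time_emb, x_mask)
print(y.shape)  # torch.Size([2, 6, 32])
```

Because the mask is broadcast over the head axis, one (B, N, N) mask serves every head.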

forward_varlen(x, rotary_time_emb, x_seqlen)

Forward pass for variable-length sequences.

Shape:
  • x: (N_total, D)

  • rotary_time_emb: (N_total, D_h)

  • x_seqlen: (B,)

  • Output: (N_total, D)

where N_total is the total number of tokens across all sequences in the batch (the sum of x_seqlen), B is the batch size (the number of sequences), D is the input dimension (dim), and D_h is the head dimension (dim_head).
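Semantically, variable-length attention is ordinary attention in which each token may attend only to tokens of its own sequence. The sketch below expresses that with a block-diagonal boolean mask built from x_seqlen and checks it against running each sequence separately. It operates on already-projected queries, keys and values laid out as (1, H, N_total, D_h); the helper names are hypothetical. This version still materializes a dense N_total-by-N_total mask, so it illustrates the semantics rather than the memory savings; fused variable-length kernels avoid building that mask.

```python
import torch
import torch.nn.functional as F


def block_diagonal_mask(seqlen: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (N_total, N_total) mask, True where query and key share a sequence."""
    seq_id = torch.repeat_interleave(torch.arange(len(seqlen)), seqlen)  # (N_total,)
    return seq_id[:, None] == seq_id[None, :]


def varlen_attention(q, k, v, seqlen):
    """Attention over concatenated sequences; q, k, v are (1, H, N_total, D_h)."""
    return F.scaled_dot_product_attention(q, k, v, attn_mask=block_diagonal_mask(seqlen))


torch.manual_seed(0)
x_seqlen = torch.tensor([3, 1, 4])
N_total, H, D_h = int(x_seqlen.sum()), 2, 8
q, k, v = (torch.randn(1, H, N_total, D_h) for _ in range(3))

out = varlen_attention(q, k, v, x_seqlen)  # (1, H, N_total, D_h)

# Reference: attend within each sequence separately, then concatenate along the token axis.
sizes = x_seqlen.tolist()
ref = torch.cat(
    [
        F.scaled_dot_product_attention(qs, ks, vs)
        for qs, ks, vs in zip(q.split(sizes, dim=2), k.split(sizes, dim=2), v.split(sizes, dim=2))
    ],
    dim=2,
)
print(torch.allclose(out, ref, atol=1e-5))  # True
```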