RotarySelfAttention
- class torch_brain.nn.RotarySelfAttention(*, dim, heads=8, dim_head=64, dropout=0.0, rotate_value=False, use_xformers=True)
Bases: torch.nn.modules.module.Module
Self-attention layer with rotary positional embeddings.
This layer performs self-attention within a sequence, with rotary positional embeddings applied to the queries and keys (and optionally the values). It first normalizes the input, projects it to query, key, and value space, applies the rotary embeddings, computes attention, and projects the result back to the original dimension.
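Why rotate the queries and keys? After rotation, the attention logit between two tokens depends only on the difference of their positions, not on the absolute positions. The pure-Python sketch below checks this property. It assumes the standard RoPE formulation: adjacent channel pairs, base 10000, and a real-valued time t. Treat rotary_angles and apply_rotary as hypothetical helpers. torch_brain supplies the angles through rotary_time_emb and may pair channels differently.

```python
import math


def rotary_angles(t, dim_head, base=10000.0):
    """Hypothetical helper: RoPE angles for a token at (possibly non-integer) time t.

    Assumption: one frequency per channel pair, each angle repeated for both
    channels of its pair, giving dim_head values (like a D_h-sized embedding).
    """
    angles = []
    for i in range(dim_head // 2):
        freq = base ** (-2.0 * i / dim_head)
        angles.extend([t * freq, t * freq])
    return angles


def apply_rotary(x, angles):
    """Rotate each adjacent channel pair (x[2i], x[2i+1]) by angles[2i]."""
    out = []
    for i in range(0, len(x), 2):
        c, s = math.cos(angles[i]), math.sin(angles[i])
        x0, x1 = x[i], x[i + 1]
        out.extend([x0 * c - x1 * s, x0 * s + x1 * c])
    return out


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


dim_head = 4
q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.2]

# Query at t=5.0 with key at t=3.0, versus query at t=12.0 with key at t=10.0.
# Both pairs are 2.0 apart, so the rotated dot products agree (up to float error).
score_a = dot(apply_rotary(q, rotary_angles(5.0, dim_head)),
              apply_rotary(k, rotary_angles(3.0, dim_head)))
score_b = dot(apply_rotary(q, rotary_angles(12.0, dim_head)),
              apply_rotary(k, rotary_angles(10.0, dim_head)))
print(score_a, score_b)
```

Rotation also preserves vector norms. Applying it does not rescale the attention logits; it only encodes relative position.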
The layer provides two forward methods:
forward(): The default method, for batches whose sequences all have the same length or have been padded to a common length. When padding is used, an attention mask (x_mask) must be provided.
forward_varlen(): Takes per-sequence lengths (x_seqlen) instead of a mask, for sequences concatenated along a single token dimension rather than stacked into a padded batch. Because no padding is stored, this can be more memory efficient.
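The two methods expect different data layouts. The sketch below builds both from the same three sequences. Each scalar stands in for a D-dimensional token embedding, and pad_batch, concat_batch, and split_concat are hypothetical helpers, not part of torch_brain. The padded layout stores B x N_max slots, while the concatenated layout stores only N_total.

```python
def pad_batch(seqs, pad_value=0.0):
    """Stack sequences into a (B, N_max) layout, padding the short ones."""
    n_max = max(len(s) for s in seqs)
    padded = [list(s) + [pad_value] * (n_max - len(s)) for s in seqs]
    lengths = [len(s) for s in seqs]
    return padded, lengths


def concat_batch(seqs):
    """Concatenate sequences into a flat (N_total,) layout plus per-sequence lengths."""
    flat = [tok for s in seqs for tok in s]
    seqlen = [len(s) for s in seqs]
    return flat, seqlen


def split_concat(flat, seqlen):
    """Undo concat_batch: recover the individual sequences from the flat layout."""
    out, start = [], 0
    for n in seqlen:
        out.append(flat[start:start + n])
        start += n
    return out


seqs = [[1.0, 2.0], [3.0, 4.0, 5.0, 6.0, 7.0], [8.0]]
padded, lengths = pad_batch(seqs)
flat, seqlen = concat_batch(seqs)

print(len(padded) * len(padded[0]))  # 15 slots in the padded layout (3 x 5)
print(len(flat))                      # 8 slots in the concatenated layout
```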
- Parameters:
dim (int) – Dimension of input embeddings.
heads (int) – Number of attention heads. Defaults to 8.
dim_head (int) – Dimension of each attention head. Defaults to 64.
dropout (float) – Dropout probability. Defaults to 0.0.
rotate_value (bool) – Whether to apply rotary embeddings to the values as well as the queries and keys. Defaults to False.
use_xformers (bool) – Whether to use xformers for attention. Defaults to True.
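The parameters fix the tensor shapes inside the layer, and dim need not equal heads * dim_head. The hypothetical helper below lists the intermediate shapes. It assumes the common multi-head layout, in which the query/key/value projections have width heads * dim_head. That layout is an assumption, not something read from the torch_brain source. What the shape tables below do confirm is that D_h, the last dimension of rotary_time_emb, equals dim_head.

```python
def attention_shapes(B, N, dim, heads=8, dim_head=64):
    """Hypothetical helper: intermediate shapes in multi-head self-attention.

    Assumes inner_dim = heads * dim_head for the q/k/v projections
    (a common convention, not confirmed for torch_brain).
    """
    inner_dim = heads * dim_head
    return {
        "x": (B, N, dim),
        "q/k/v after projection": (B, N, inner_dim),
        "q/k/v split into heads": (B, heads, N, dim_head),
        "rotary_time_emb": (B, N, dim_head),  # D_h == dim_head, shared across heads
        "output": (B, N, dim),                # projected back to dim
    }


shapes = attention_shapes(2, 100, dim=128)
print(shapes["q/k/v split into heads"])  # (2, 8, 100, 64)
```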
- forward(x, rotary_time_emb, x_mask=None)
Forward pass for fixed-length sequences.
- Shape:
x: (B, N, D)
rotary_time_emb: (B, N, D_h)
x_mask: (B, N, N)
Output: (B, N, D)
where B is batch size, N is sequence length, D is input dimension, and D_h is head dimension.
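For a padded batch, an x_mask of shape (B, N, N) can be derived from the true sequence lengths. In the sketch below, make_x_mask is hypothetical, and the True-means-"may attend" convention is an assumption (it matches boolean masks in PyTorch's scaled_dot_product_attention). Check the convention torch_brain expects before relying on it. Only padded keys are masked. Padded query rows stay valid, so no row is entirely False, since an all-False row would leave the softmax undefined. Outputs at padded positions are simply ignored downstream.

```python
def make_x_mask(lengths, n):
    """Hypothetical helper: (B, N, N) boolean mask for a padded batch.

    Assumed convention: mask[b][i][j] is True when query i of sequence b
    may attend to key j. Keys at positions >= lengths[b] are padding.
    """
    return [
        [[j < length for j in range(n)] for _ in range(n)]
        for length in lengths
    ]


lengths = [2, 4]   # two sequences padded to N = 4
x_mask = make_x_mask(lengths, n=4)
print(x_mask[0][0])  # [True, True, False, False]
```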
- forward_varlen(x, rotary_time_emb, x_seqlen)
Forward pass for variable-length sequences.
- Shape:
x: (N_total, D)
rotary_time_emb: (N_total, D_h)
x_seqlen: (B,)
Output: (N_total, D)
where N_total is the total number of tokens across all sequences in the batch (the sum of x_seqlen), B is batch size, D is input dimension, and D_h is head dimension.
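Conceptually, variable-length attention over concatenated tokens is the same as attention with a block-diagonal mask: each token attends only to tokens from its own sequence. The pure-Python sketch below checks this. It uses single-head attention over scalar features, with no projections or rotary embeddings, and compares the packed result against running each sequence separately. The helpers block_diagonal and cu_seqlens are hypothetical. Cumulative offsets of this kind are a common way for varlen kernels to locate sequence boundaries, but this page does not state how torch_brain builds them internally.

```python
import math


def softmax_attention(q, k, v, allowed):
    """Single-head attention on scalar features; allowed[i][j] gates key j for query i."""
    out = []
    for i, qi in enumerate(q):
        scores = [qi * kj if allowed[i][j] else -math.inf for j, kj in enumerate(k)]
        m = max(scores)                        # finite: each query sees >= 1 key
        w = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0 for masked keys
        out.append(sum(wj * vj for wj, vj in zip(w, v)) / sum(w))
    return out


def block_diagonal(x_seqlen):
    """allowed[i][j] is True iff tokens i and j belong to the same sequence."""
    seg = [b for b, n in enumerate(x_seqlen) for _ in range(n)]
    return [[si == sj for sj in seg] for si in seg]


def cu_seqlens(x_seqlen):
    """Cumulative start offsets: sequence b occupies tokens [cu[b], cu[b + 1])."""
    cu = [0]
    for n in x_seqlen:
        cu.append(cu[-1] + n)
    return cu


x_seqlen = [2, 3]
q = [0.5, -1.0, 0.2, 0.9, -0.4]   # N_total = 5 tokens, scalar features
k = [1.0, 0.3, -0.6, 0.8, 0.1]
v = [2.0, -1.0, 0.5, 1.5, -2.0]

packed = softmax_attention(q, k, v, block_diagonal(x_seqlen))

# Reference: attend within each sequence separately, then concatenate.
cu = cu_seqlens(x_seqlen)
separate = []
for b in range(len(x_seqlen)):
    start, end = cu[b], cu[b + 1]
    n = end - start
    full = [[True] * n for _ in range(n)]
    separate += softmax_attention(q[start:end], k[start:end], v[start:end], full)

print(cu)  # [0, 2, 5]
```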