RotaryCrossAttention

class torch_brain.nn.RotaryCrossAttention(*, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, rotate_value=False, use_xformers=True)

Bases: torch.nn.modules.module.Module

Cross-attention layer with rotary positional embeddings.

This layer performs cross-attention from a query sequence to a context sequence, with rotary positional embeddings applied to the queries and keys (and optionally to the values). It first normalizes the query and context inputs, projects them into query/key/value space, applies the rotary embeddings, computes attention, and finally projects the result back to the query dimension dim.
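As an illustration of the rotary mechanism in general, not torch_brain's own code, the sketch below rotates each pair of channels by an angle proportional to the token's position. It assumes the interleaved-pair channel layout and the common RoPE frequency base of 10000; the helpers rotate_half, rotary_angles and apply_rotary are hypothetical. The final check shows the property that motivates rotary embeddings: the query-key score depends only on the offset between the two positions, not on their absolute values.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # View the last dim as interleaved pairs (a, b) and map each pair to (-b, a).
    x = x.reshape(*x.shape[:-1], -1, 2)
    a, b = x.unbind(-1)
    return torch.stack((-b, a), dim=-1).flatten(-2)


def rotary_angles(positions: torch.Tensor, dim_head: int, base: float = 10000.0) -> torch.Tensor:
    # Hypothetical helper: one frequency per channel pair, repeated so that the
    # last dim equals dim_head and lines up with the pairs used by rotate_half.
    inv_freq = base ** (-torch.arange(0, dim_head, 2, dtype=torch.float32) / dim_head)
    angles = positions[..., None].float() * inv_freq   # (..., dim_head // 2)
    return angles.repeat_interleave(2, dim=-1)          # (..., dim_head)


def apply_rotary(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # 2-D rotation of every (a, b) pair: (a cos - b sin, b cos + a sin).
    return x * angles.cos() + rotate_half(x) * angles.sin()


torch.manual_seed(0)
dim_head = 8
q, k = torch.randn(dim_head), torch.randn(dim_head)


def score(q_pos: float, k_pos: float) -> torch.Tensor:
    # Dot product between a query at q_pos and a key at k_pos after rotation.
    q_rot = apply_rotary(q, rotary_angles(torch.tensor(q_pos), dim_head))
    k_rot = apply_rotary(k, rotary_angles(torch.tensor(k_pos), dim_head))
    return q_rot @ k_rot


# The same offset (2.0) at different absolute positions gives the same score.
same = torch.allclose(score(3.0, 1.0), score(10.0, 8.0), atol=1e-5)
```

Because each pair is rotated rather than shifted, the rotation also leaves vector norms unchanged; only the relative angle between query and key pairs carries positional information.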

The layer provides two forward methods:

  • forward(): The default method, for batches whose sequences all have the same length or have been padded to a common length. When the context is padded, pass context_mask to mark which context positions are valid.

  • forward_varlen(): For sequences concatenated along a single token dimension rather than stacked into a padded batch. Per-sequence lengths take the place of the attention mask. Skipping the padding can save memory, but the inputs must be concatenated rather than stacked.
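The two input layouts can be related with plain PyTorch. The following sketch, with made-up sizes, turns a padded batch into the concatenated tokens plus per-sequence lengths that forward_varlen() takes; the variable names are illustrative and not part of the torch_brain API.

```python
import torch

torch.manual_seed(0)
# Hypothetical padded batch: two sequences of lengths 3 and 1, padded to N = 3 tokens.
B, N, D = 2, 3, 4
x_padded = torch.randn(B, N, D)
lengths = torch.tensor([3, 1])

# (B, N) boolean mask: True for real tokens, False for padding.
mask = torch.arange(N)[None, :] < lengths[:, None]

# Keep only the real tokens, sequence after sequence: (N_total, D) with N_total = 4.
x_packed = x_padded[mask]
# (B,) lengths that mark where each sequence starts and ends in x_packed.
seqlen = mask.sum(dim=1)
```

The padded tensor stores B * N = 6 token rows while the packed one stores 4; that difference is the padding overhead the varlen layout avoids.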

Parameters:
  • dim (int) – Dimension of input query embeddings

  • context_dim (Optional[int]) – Dimension of input context embeddings. If None, defaults to dim.

  • heads (int) – Number of attention heads

  • dim_head (int) – Dimension of each attention head

  • dropout (float) – Dropout probability

  • rotate_value (bool) – Whether to apply rotary embeddings to values as well as queries/keys

  • use_xformers (bool) – Whether to use xformers for attention. Defaults to True.
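To show how the constructor arguments could map onto layers, here is a minimal sketch of the pipeline described above, written from this page's description rather than from torch_brain's source. Its assumptions are all hypothetical: LayerNorm for the normalization step, an inner dimension of heads * dim_head, query_pos_emb and context_pos_emb holding per-channel rotation angles, interleaved channel pairs for the rotation, and, when rotate_value is set, an inverse rotation of the output by the query angles. It uses torch.nn.functional.scaled_dot_product_attention in place of xformers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rotate_half(x):
    # Interleaved pairs (a, b) -> (-b, a); assumed channel layout.
    x = x.reshape(*x.shape[:-1], -1, 2)
    a, b = x.unbind(-1)
    return torch.stack((-b, a), dim=-1).flatten(-2)


def apply_rotary(x, angles):
    return x * angles.cos() + rotate_half(x) * angles.sin()


class RotaryCrossAttentionSketch(nn.Module):
    """Hypothetical re-implementation for illustration only."""

    def __init__(self, *, dim, context_dim=None, heads=8, dim_head=64,
                 dropout=0.0, rotate_value=False):
        super().__init__()
        context_dim = context_dim or dim          # None -> same as dim
        inner_dim = heads * dim_head
        self.heads, self.dropout, self.rotate_value = heads, dropout, rotate_value
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)   # back to the query dimension

    def _split_heads(self, t):
        # (B, N, H * D_h) -> (B, H, N, D_h)
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    def forward(self, x_query, x_context, query_pos_emb, context_pos_emb,
                context_mask=None):
        q = self._split_heads(self.to_q(self.norm(x_query)))
        k, v = self.to_kv(self.norm_context(x_context)).chunk(2, dim=-1)
        k, v = self._split_heads(k), self._split_heads(v)

        # (B, N, D_h) angles broadcast over the head dimension.
        q_angles = query_pos_emb.unsqueeze(1)
        c_angles = context_pos_emb.unsqueeze(1)
        q, k = apply_rotary(q, q_angles), apply_rotary(k, c_angles)
        if self.rotate_value:
            v = apply_rotary(v, c_angles)

        # Boolean mask (B, N_c) -> (B, 1, 1, N_c); True means "attend".
        attn_mask = None if context_mask is None else context_mask[:, None, None, :]
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0)

        if self.rotate_value:
            # Assumed: undo the query rotation so the output depends on
            # the offsets between query and context positions.
            out = apply_rotary(out, -q_angles)

        b, _, n_q, _ = out.shape
        out = out.transpose(1, 2).reshape(b, n_q, -1)   # (B, N_q, H * D_h)
        return self.to_out(out)


layer = RotaryCrossAttentionSketch(dim=32, context_dim=16, heads=4, dim_head=8).eval()
out = layer(torch.randn(2, 5, 32), torch.randn(2, 7, 16),
            torch.randn(2, 5, 8), torch.randn(2, 7, 8))
print(out.shape)  # torch.Size([2, 5, 32])
```

Note that the last dimension of both positional-embedding tensors equals dim_head, not dim, because the rotation is applied within each head.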

forward(x_query, x_context, query_pos_emb, context_pos_emb, context_mask=None)

Forward pass for regular or padded sequences.

Shape:
  • x_query: (B, N_q, D_q)

  • x_context: (B, N_c, D_c)

  • query_pos_emb: (B, N_q, D_h)

  • context_pos_emb: (B, N_c, D_h)

  • context_mask: (B, N_c), optional

  • Output: (B, N_q, D_q)

where B is the batch size, N_q and N_c are the query and context sequence lengths, D_q is the query dimension (dim), D_c is the context dimension (context_dim), H is the number of heads (heads), and D_h is the head dimension (dim_head).
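The effect of a context mask can be checked with PyTorch's own attention function. The sketch below assumes that True marks a real context token, which is the convention of the boolean attn_mask in torch.nn.functional.scaled_dot_product_attention; it works on raw per-head tensors rather than on this layer, and the sizes are made up. Masking the padded tail gives the same result as attending over the unpadded context.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N_q, N_c, D_h = 2, 2, 3, 5, 4
q = torch.randn(B, H, N_q, D_h)
k = torch.randn(B, H, N_c, D_h)
v = torch.randn(B, H, N_c, D_h)

# Sample 0 uses all 5 context tokens; sample 1 has 2 real tokens and 3 padding.
context_lengths = torch.tensor([5, 2])
context_mask = torch.arange(N_c)[None, :] < context_lengths[:, None]  # (B, N_c)

# (B, N_c) -> (B, 1, 1, N_c): the same mask applies to every head and query.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=context_mask[:, None, None, :])

# Reference for sample 1: drop the padding instead of masking it.
ref = F.scaled_dot_product_attention(q[1:], k[1:, :, :2], v[1:, :, :2])
matches = torch.allclose(out[1:], ref, atol=1e-5)
```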

forward_varlen(x_query, x_context, query_pos_emb, context_pos_emb, query_seqlen, context_seqlen)

Forward pass for variable length sequences.

Similar to forward(), but for variable-length sequences concatenated along the first (token) dimension instead of being stacked and padded. The lengths in query_seqlen and context_seqlen mark where each sequence ends and the next begins. Avoiding padding can be more memory efficient, at the cost of requiring the concatenated layout.

Shape:
  • x_query: (N_q_total, D)

  • x_context: (N_c_total, D_c)

  • query_pos_emb: (N_q_total, D_h)

  • context_pos_emb: (N_c_total, D_h)

  • query_seqlen: (B,)

  • context_seqlen: (B,)

  • Output: (N_q_total, D)

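where N_q_total and N_c_total are the total numbers of query and context tokens across the batch (the sums of query_seqlen and context_seqlen), B is the batch size, D is the query dimension (dim), D_c is the context dimension (context_dim), H is the number of heads (heads), and D_h is the head dimension (dim_head).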
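The sketch below spells out the attention pattern that query_seqlen and context_seqlen imply: each query attends only to context tokens from its own sequence, which amounts to a block-diagonal mask over the concatenated tokens. It works on raw per-head tensors with made-up sizes and uses torch.nn.functional.scaled_dot_product_attention; it is a reference for the semantics, not a description of how the layer computes it.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
H, D_h = 2, 4
query_seqlen = torch.tensor([2, 3])      # (B,) with B = 2
context_seqlen = torch.tensor([4, 1])    # every sequence has at least one context token
N_q_total, N_c_total = int(query_seqlen.sum()), int(context_seqlen.sum())
q = torch.randn(H, N_q_total, D_h)
k = torch.randn(H, N_c_total, D_h)
v = torch.randn(H, N_c_total, D_h)

# Sequence index of every concatenated token, e.g. [0, 0, 1, 1, 1] for the queries.
q_ids = torch.repeat_interleave(torch.arange(len(query_seqlen)), query_seqlen)
c_ids = torch.repeat_interleave(torch.arange(len(context_seqlen)), context_seqlen)

# Block-diagonal mask: a query attends only to context tokens of its own sequence.
block_mask = q_ids[:, None] == c_ids[None, :]    # (N_q_total, N_c_total)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=block_mask)

# Reference: attend within each sequence separately, then concatenate.
ref = torch.cat([
    F.scaled_dot_product_attention(q_i, k_i, v_i)
    for q_i, k_i, v_i in zip(q.split(query_seqlen.tolist(), dim=1),
                             k.split(context_seqlen.tolist(), dim=1),
                             v.split(context_seqlen.tolist(), dim=1))
], dim=1)
```

The dense (N_q_total, N_c_total) mask grows quadratically with the total token count; kernels that take the lengths directly can avoid materializing it.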