Transformer modules¶
| Class | Description |
| --- | --- |
| FeedForward | Feed-forward network with GEGLU activation. |
| RotaryCrossAttention | Rotary cross-attention layer. |
| RotarySelfAttention | Rotary self-attention layer. |
- class FeedForward(dim, mult=4, dropout=0.2)[source]¶
A feed-forward network with GEGLU activation.
- Parameters:
dim (int) – Dimension of the input embeddings
mult (int) – Multiplier for the hidden layer dimension
dropout (float) – Dropout probability
- forward(x)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
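Below is a minimal usage sketch. The import path (torch_brain.nn) and the assumption that the output width equals the input dim are inferred from this page rather than guaranteed; adjust the import to wherever the class lives in your installation.

```python
import torch
from torch_brain.nn import FeedForward  # assumed import path; adjust to your install

ff = FeedForward(dim=128, mult=4, dropout=0.2)

x = torch.randn(2, 100, 128)   # (batch, tokens, dim)
out = ff(x)                    # GEGLU feed-forward applied position-wise
print(out.shape)               # expected: torch.Size([2, 100, 128])
```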
- class RotaryCrossAttention(*, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, rotate_value=False)[source]¶
Cross-attention layer with rotary positional embeddings.
This layer performs cross-attention between a query sequence and a context sequence, with rotary positional embeddings applied to the queries and keys (and optionally values). It first normalizes the inputs, projects them to query/key/value space, applies rotary embeddings and attention, then projects back to the original dimension.
The layer provides two forward methods:
- forward(): This is the default, used for sequences in a batch that are of the same length or are padded to the same length. When padding is used, an attention mask needs to be provided.
- forward_varlen(): Uses sequence lengths instead of masks for sequences that are chained together in a single batch dimension. This can be more memory efficient since it avoids padding, but requires the sequences to be concatenated rather than stacked.
- Parameters:
dim (int) – Dimension of input query embeddings
context_dim (Optional[int]) – Dimension of input context embeddings. If None, uses same as dim
heads (int) – Number of attention heads
dim_head (int) – Dimension of each attention head
dropout (float) – Dropout probability
rotate_value (bool) – Whether to apply rotary embeddings to values as well as queries/keys
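As a construction sketch (assuming the same torch_brain.nn import path as above), the snippet below instantiates a cross-attention layer whose queries have width 64 and whose context has width 128; the sizes are illustrative only.

```python
from torch_brain.nn import RotaryCrossAttention  # assumed import path

# Queries of width 64 attend over a context of width 128,
# using 8 heads of 32 dimensions each.
xattn = RotaryCrossAttention(
    dim=64,
    context_dim=128,
    heads=8,
    dim_head=32,
    dropout=0.1,
    rotate_value=False,  # rotate only queries/keys, not values
)
```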
- forward(x_query, x_context, query_pos_emb, context_pos_emb, context_mask=None)[source]¶
Forward pass for regular or padded sequences.
- Shape:
x_query: (B, N_q, D)
x_context: (B, N_c, D_c)
query_pos_emb: (B, N_q, D_h)
context_pos_emb: (B, N_c, D_h)
context_mask: Optional[Tensor[B, N_c]]
Output: (B, N_q, D)
where B is the batch size, N_q is the query sequence length, N_c is the context sequence length, D is the query embedding dimension, D_c is the context embedding dimension, and D_h is the head dimension.
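A sketch of the padded forward() path, matching the shapes above. The random query_pos_emb / context_pos_emb tensors stand in for rotary embeddings that would normally be computed from token positions, and the boolean mask convention (True for valid tokens) is an assumption to verify against the implementation.

```python
import torch
from torch_brain.nn import RotaryCrossAttention  # assumed import path

xattn = RotaryCrossAttention(dim=64, context_dim=128, heads=8, dim_head=32)

B, N_q, N_c, D_h = 2, 16, 100, 32

x_query = torch.randn(B, N_q, 64)         # (B, N_q, D)
x_context = torch.randn(B, N_c, 128)      # (B, N_c, D_c)

# Placeholder rotary embeddings; in practice these are derived from the
# query/context positions (e.g. timestamps) by a rotary embedding module.
query_pos_emb = torch.randn(B, N_q, D_h)
context_pos_emb = torch.randn(B, N_c, D_h)

# One mask entry per context token; the second sample has only 80 real tokens.
# (True-for-valid is an assumption; check the mask convention of your version.)
context_mask = torch.ones(B, N_c, dtype=torch.bool)
context_mask[1, 80:] = False

out = xattn(x_query, x_context, query_pos_emb, context_pos_emb,
            context_mask=context_mask)
print(out.shape)                          # expected: torch.Size([2, 16, 64])
```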
- forward_varlen(x_query, x_context, query_pos_emb, context_pos_emb, query_seqlen, context_seqlen)[source]¶
Forward pass for variable length sequences.
Similar to forward() but handles variable length sequences that have been chained together in the batch dimension rather than being stacked and padded. This approach can be more memory efficient since it avoids padding, but requires the sequences to be concatenated rather than stacked.
- Shape:
x_query: (N_q_total, D)
x_context: (N_c_total, D_c)
query_pos_emb: (N_q_total, D_h)
context_pos_emb: (N_c_total, D_h)
query_seqlen: (B,)
context_seqlen: (B,)
Output: (N_q_total, D)
where N_q_total and N_c_total are the total query and context sequence lengths across the batch, B is the batch size, D is the query embedding dimension, D_c is the context embedding dimension, and D_h is the head dimension.
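A corresponding sketch for forward_varlen(), again under the assumed import path. Samples are concatenated along a single token dimension and described by per-sample lengths; the dtype/device requirements of the underlying variable-length attention kernel may be stricter than shown here, so treat this as shape bookkeeping only.

```python
import torch
from torch_brain.nn import RotaryCrossAttention  # assumed import path

xattn = RotaryCrossAttention(dim=64, context_dim=128, heads=8, dim_head=32)

query_seqlen = torch.tensor([12, 20])      # (B,) query tokens per sample
context_seqlen = torch.tensor([80, 100])   # (B,) context tokens per sample
N_q_total = int(query_seqlen.sum())        # 32
N_c_total = int(context_seqlen.sum())      # 180

x_query = torch.randn(N_q_total, 64)           # (N_q_total, D)
x_context = torch.randn(N_c_total, 128)        # (N_c_total, D_c)
query_pos_emb = torch.randn(N_q_total, 32)     # placeholder rotary embeddings
context_pos_emb = torch.randn(N_c_total, 32)

out = xattn.forward_varlen(
    x_query, x_context, query_pos_emb, context_pos_emb,
    query_seqlen=query_seqlen, context_seqlen=context_seqlen,
)
print(out.shape)                               # expected: torch.Size([32, 64])
```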
- class RotarySelfAttention(*, dim, heads=8, dim_head=64, dropout=0.0, rotate_value=False)[source]¶
Self-attention layer with rotary positional embeddings.
This layer performs self-attention within a sequence, with rotary positional embeddings applied to the queries and keys (and optionally values). It first normalizes the input, projects it to query/key/value space, applies rotary embeddings and attention, then projects back to the original dimension.
The layer provides two forward methods:
- forward(): This is the default, used for sequences in a batch that are of the same length or are padded to the same length. When padding is used, an attention mask needs to be provided.
- forward_varlen(): Uses sequence lengths instead of masks for sequences that are chained together in a single batch dimension. This can be more memory efficient since it avoids padding, but requires the sequences to be concatenated rather than stacked.
- Parameters:
dim (int) – Dimension of input embeddings
heads (int) – Number of attention heads
dim_head (int) – Dimension of each attention head
dropout (float) – Dropout probability
rotate_value (bool) – Whether to apply rotary embeddings to values as well as queries/keys
- forward(x, rotary_time_emb, x_mask=None)[source]¶
Forward pass for fixed-length sequences.
- Shape:
x: (B, N, D)
rotary_time_emb: (B, N, D_h)
x_mask: Optional[Tensor[B, N, N]]
Output: (B, N, D)
where B is batch size, N is sequence length, D is input dimension, and D_h is head dimension.
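A usage sketch of the padded self-attention path under the same assumed import; the rotary tensor is a placeholder, and the (B, N, N) boolean mask follows the shape spec above with its True/False meaning left to the implementation.

```python
import torch
from torch_brain.nn import RotarySelfAttention  # assumed import path

attn = RotarySelfAttention(dim=64, heads=8, dim_head=32, dropout=0.1)

B, N, D_h = 2, 50, 32
x = torch.randn(B, N, 64)                   # (B, N, D)
rotary_time_emb = torch.randn(B, N, D_h)    # placeholder rotary embeddings

# Optional pairwise mask over the sequence, shaped (B, N, N) as documented.
x_mask = torch.ones(B, N, N, dtype=torch.bool)

out = attn(x, rotary_time_emb, x_mask=x_mask)
print(out.shape)                            # expected: torch.Size([2, 50, 64])
```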
- forward_varlen(x, rotary_time_emb, x_seqlen)[source]¶
Forward pass for variable-length sequences.
- Shape:
x: (N_total, D)
rotary_time_emb: (N_total, D_h)
x_seqlen: (B,)
Output: (N_total, D)
where N_total is the total sequence length across the batch, B is batch size, D is input dimension, and D_h is head dimension.
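And the variable-length counterpart, mirroring the cross-attention sketch above, with the same caveats about the assumed import path, placeholder rotary embeddings, and possible kernel-specific dtype/device requirements.

```python
import torch
from torch_brain.nn import RotarySelfAttention  # assumed import path

attn = RotarySelfAttention(dim=64, heads=8, dim_head=32)

x_seqlen = torch.tensor([30, 50])            # (B,) tokens per sample
N_total = int(x_seqlen.sum())                # 80

x = torch.randn(N_total, 64)                 # (N_total, D)
rotary_time_emb = torch.randn(N_total, 32)   # placeholder rotary embeddings

out = attn.forward_varlen(x, rotary_time_emb, x_seqlen=x_seqlen)
print(out.shape)                             # expected: torch.Size([80, 64])
```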