# Source code for torch_brain.nn.rotary_attention

from typing import Optional

import torch
import torch.nn.functional as F
import torch.nn as nn
from einops import rearrange, repeat

try:
    import xformers.ops as xops
except ImportError:
    xops = None


from torch_brain.nn import RotaryTimeEmbedding


class RotaryCrossAttention(nn.Module):
    """Cross-attention layer with rotary positional embeddings.

    This layer performs cross-attention between a query sequence and a context
    sequence, with rotary positional embeddings applied to the queries and keys
    (and optionally values). It first normalizes the inputs, projects them to
    query/key/value space, applies rotary embeddings and attention, then projects
    back to the original dimension.

    The layer provides two forward methods:

    - forward(): This is the default, and is used for sequences in a batch that
      are of the same length, or are padded to the same length. When padding is
      used, attention masks need to be provided.
    - forward_varlen(): Uses sequence lengths instead of masks for sequences that
      are chained together in a single batch dimension. This can be more memory
      efficient since it avoids padding, but requires the sequences to be
      concatenated rather than stacked.

    Args:
        dim (int): Dimension of input query embeddings
        context_dim (Optional[int]): Dimension of input context embeddings.
            If None, uses same as dim
        heads (int): Number of attention heads
        dim_head (int): Dimension of each attention head
        dropout (float): Dropout probability
        rotate_value (bool): Whether to apply rotary embeddings to values as well
            as queries/keys
    """

    def __init__(
        self,
        *,
        dim: int,
        context_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        rotate_value: bool = False,
    ):
        super().__init__()

        inner_dim = dim_head * heads
        context_dim = context_dim or dim

        self.heads = heads
        self.dropout = dropout
        self.rotate_value = rotate_value

        # queries and context are normalized independently before projection
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim)

        # keys and values come from one fused projection that is split in forward
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x_query,
        x_context,
        query_pos_emb,
        context_pos_emb,
        context_mask=None,
    ):
        """Forward pass for regular or padded sequences.

        Shape:
            - x_query: (B, N_q, D_q)
            - x_context: (B, N_c, D_c)
            - query_pos_emb: (B, N_q, D_h)
            - context_pos_emb: (B, N_c, D_h)
            - context_mask: Optional[Tensor[B, N_c]]
            - Output: (B, N_q, D_q)

            where B is batch size, N_q is query sequence length, N_c is context
            sequence length, D_q is input dimension, D_c is context dimension,
            and D_h is head dimension.
        """
        # normalize
        x_query = self.norm(x_query)
        x_context = self.norm_context(x_context)

        # project to q, k, v
        q = self.to_q(x_query)
        k, v = self.to_kv(x_context).chunk(2, dim=-1)

        # select attention kernel
        if xops is not None and x_query.device.type == "cuda":
            # if xformers is available, use it for attention.
            # xformers supports attention masks when using the memory efficient
            # attention kernel, but pytorch does not.
            rotary_attn_func = rotary_attn_xformers_func
        else:
            # otherwise use pytorch's default attention which will determine the
            # best attention kernel (math, mem_efficient or flash) based on the
            # hardware and other factors.
            rotary_attn_func = rotary_attn_pytorch_func

        # apply attention (dropout is only active in training mode)
        out = rotary_attn_func(
            query=q,
            key=k,
            value=v,
            q_pos_emb=query_pos_emb,
            kv_pos_emb=context_pos_emb,
            num_heads=self.heads,
            dropout_p=self.dropout if self.training else 0,
            rotate_value=self.rotate_value,
            attn_mask=context_mask,
        )

        # project back to dim
        return self.to_out(out)

    def forward_varlen(
        self,
        x_query,
        x_context,
        query_pos_emb,
        context_pos_emb,
        query_seqlen,
        context_seqlen,
    ):
        """Forward pass for variable length sequences.

        Similar to forward() but handles variable length sequences that have been
        chained together in the batch dimension rather than being stacked and
        padded. This approach can be more memory efficient since it avoids
        padding, but requires the sequences to be concatenated rather than
        stacked.

        Shape:
            - x_query: (N_q_total, D)
            - x_context: (N_c_total, D_c)
            - query_pos_emb: (N_q_total, D_h)
            - context_pos_emb: (N_c_total, D_h)
            - query_seqlen: (B,)
            - context_seqlen: (B,)
            - Output: (N_q_total, D)

            where N_q_total and N_c_total are the total sequence lengths across
            the batch, B is batch size, D is input dimension, D_c is context
            dimension, and D_h is head dimension.

        Raises:
            NotImplementedError: If the inputs are not on a CUDA device;
                forward() should be used instead.
            RuntimeError: If the inputs are on a CUDA device but xformers is
                not installed.
        """
        # fail fast: the only varlen kernel is the xformers one and it is only
        # used on CUDA, so check availability before doing any computation
        if x_query.device.type != "cuda":
            # forward_varlen is not implemented for CPU, forward should be used
            # instead
            raise NotImplementedError("No varlen attention kernel available for CPU.")
        if xops is None:
            raise RuntimeError(
                "No varlen attention kernel available, please install xformers."
            )

        # normalize
        x_query = self.norm(x_query)
        x_context = self.norm_context(x_context)

        # project to q, k, v
        q = self.to_q(x_query)
        k, v = self.to_kv(x_context).chunk(2, dim=-1)

        # apply attention (dropout is only active in training mode)
        out = rotary_attn_xformers_varlen_func(
            query=q,
            key=k,
            value=v,
            q_pos_emb=query_pos_emb,
            kv_pos_emb=context_pos_emb,
            num_heads=self.heads,
            dropout_p=self.dropout if self.training else 0,
            rotate_value=self.rotate_value,
            q_seqlen=query_seqlen,
            kv_seqlen=context_seqlen,
        )

        # project back to dim
        return self.to_out(out)
class RotarySelfAttention(nn.Module):
    """Self-attention layer with rotary positional embeddings.

    This layer performs self-attention within a sequence, with rotary positional
    embeddings applied to the queries and keys (and optionally values). It first
    normalizes the input, projects it to query/key/value space, applies rotary
    embeddings and attention, then projects back to the original dimension.

    The layer provides two forward methods:

    - forward(): This is the default, and is used for sequences in a batch that
      are of the same length, or are padded to the same length. When padding is
      used, attention masks need to be provided.
    - forward_varlen(): Uses sequence lengths instead of masks for sequences that
      are chained together in a single batch dimension. This can be more memory
      efficient since it avoids padding, but requires the sequences to be
      concatenated rather than stacked.

    Args:
        dim (int): Dimension of input embeddings
        heads (int): Number of attention heads
        dim_head (int): Dimension of each attention head
        dropout (float): Dropout probability
        rotate_value (bool): Whether to apply rotary embeddings to values as well
            as queries/keys
    """

    def __init__(
        self,
        *,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        rotate_value: bool = False,
    ):
        super().__init__()

        inner_dim = dim_head * heads

        self.heads = heads
        self.dropout = dropout
        self.rotate_value = rotate_value

        self.norm = nn.LayerNorm(dim)
        # q, k and v come from one fused projection that is split in forward
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x,
        rotary_time_emb,
        x_mask=None,
    ):
        """Forward pass for fixed-length sequences.

        Shape:
            - x: (B, N, D)
            - rotary_time_emb: (B, N, D_h)
            - x_mask: Optional[Tensor[B, N]]; the attention wrappers this method
              dispatches to broadcast a 2-D mask over heads and query positions
            - Output: (B, N, D)

            where B is batch size, N is sequence length, D is input dimension,
            and D_h is head dimension.
        """
        # normalize
        x = self.norm(x)

        # project to q, k, v
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # select attention kernel: xformers on CUDA when installed, otherwise
        # pytorch's scaled dot product attention
        if xops is not None and x.device.type == "cuda":
            rotary_attn_func = rotary_attn_xformers_func
        else:
            rotary_attn_func = rotary_attn_pytorch_func

        # apply attention; the same positional embedding rotates queries and keys
        out = rotary_attn_func(
            query=q,
            key=k,
            value=v,
            q_pos_emb=rotary_time_emb,
            kv_pos_emb=rotary_time_emb,
            num_heads=self.heads,
            dropout_p=self.dropout if self.training else 0,
            rotate_value=self.rotate_value,
            attn_mask=x_mask,
        )

        # project back to dim
        return self.to_out(out)

    def forward_varlen(
        self,
        x,
        rotary_time_emb,
        x_seqlen,
    ):
        """Forward pass for variable-length sequences.

        Shape:
            - x: (N_total, D)
            - rotary_time_emb: (N_total, D_h)
            - x_seqlen: (B,)
            - Output: (N_total, D)

            where N_total is the total sequence length across the batch, B is
            batch size, D is input dimension, and D_h is head dimension.

        Raises:
            NotImplementedError: If the input is not on a CUDA device;
                forward() should be used instead.
            RuntimeError: If the input is on a CUDA device but xformers is not
                installed.
        """
        # fail fast: the only varlen kernel is the xformers one and it is only
        # used on CUDA, so check availability before doing any computation
        if x.device.type != "cuda":
            # forward_varlen is not implemented for CPU, forward should be used
            # instead
            raise NotImplementedError("No varlen attention kernel available for CPU.")
        if xops is None:
            raise RuntimeError(
                "No varlen attention kernel available, please install xformers."
            )

        # normalize
        x = self.norm(x)

        # project to q, k, v
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # apply attention
        out = rotary_attn_xformers_varlen_func(
            query=q,
            key=k,
            value=v,
            q_pos_emb=rotary_time_emb,
            kv_pos_emb=rotary_time_emb,
            num_heads=self.heads,
            dropout_p=self.dropout if self.training else 0,
            rotate_value=self.rotate_value,
            q_seqlen=x_seqlen,
            kv_seqlen=None,  # self-attention has the same seqlen for q, k, v
        )

        # project back to dim
        return self.to_out(out)
def rotary_attn_pytorch_func(
    *,
    query,
    key,
    value,
    q_pos_emb,
    kv_pos_emb,
    attn_mask=None,
    num_heads: int,
    dropout_p: float,
    rotate_value: bool,
):
    r"""Wraps the default attention implementation with rotary embedding application.

    Uses the default scaled dot product attention from pytorch
    https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    This implements basic versions of memory efficient attention and flash
    attention, but more advanced versions are available in xformers and
    flash_attn (varlen) which allow us to perform complex masking operations.

    Args:
        query: The query tensor, with shape (b, n_q, (h d))
        key: The key tensor, with shape (b, n_kv, (h d))
        value: The value tensor, with shape (b, n_kv, (h d))
        q_pos_emb: The query rotary position embedding, with shape (b, n_q, d)
        kv_pos_emb: The key rotary position embedding, with shape (b, n_kv, d)
        attn_mask: The attention mask, with shape (b, n_kv)
        num_heads: The number of attention heads
        dropout_p: The dropout probability
        rotate_value: Whether to rotate the value in addition to the query and key

    Returns:
        The output tensor, with shape (b, n_q, (h d))
    """
    # default attention expects shape b h n d
    query = rearrange(query, "b n (h d) -> b h n d", h=num_heads)
    key = rearrange(key, "b n (h d) -> b h n d", h=num_heads)
    value = rearrange(value, "b n (h d) -> b h n d", h=num_heads)

    # apply rotary embeddings
    query = RotaryTimeEmbedding.rotate(x=query, rotary_emb=q_pos_emb, unsqueeze_dim=1)
    key = RotaryTimeEmbedding.rotate(x=key, rotary_emb=kv_pos_emb, unsqueeze_dim=1)
    if rotate_value:
        value = RotaryTimeEmbedding.rotate(
            x=value, rotary_emb=kv_pos_emb, unsqueeze_dim=1
        )

    # attention mask: add singleton head and query axes so that the per-key mask
    # broadcasts over them
    if attn_mask is not None:
        attn_mask = rearrange(attn_mask, "b n -> b () () n")

    # perform attention, by default will use the optimal attention implementation
    out = F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
    )

    # when values were rotated, rotate the output with the inverted query
    # embedding
    if rotate_value:
        out = RotaryTimeEmbedding.rotate(
            x=out,
            rotary_emb=RotaryTimeEmbedding.invert(q_pos_emb),
            unsqueeze_dim=1,
        )

    # return (b, n, (h d))
    out = rearrange(out, "b h n d -> b n (h d)")
    return out


def rotary_attn_xformers_func(
    *,
    query,
    key,
    value,
    q_pos_emb,
    kv_pos_emb,
    attn_mask=None,
    num_heads: int,
    dropout_p: float,
    rotate_value: bool,
):
    r"""Wraps the memory efficient attention implementation with rotary embedding
    application.

    Args:
        query: The query tensor, with shape (b, n_q, (h d))
        key: The key tensor, with shape (b, n_kv, (h d))
        value: The value tensor, with shape (b, n_kv, (h d))
        q_pos_emb: The query rotary position embedding, with shape (b, n_q, d)
        kv_pos_emb: The key rotary position embedding, with shape (b, n_kv, d)
        attn_mask: The attention mask, with shape (b, n_kv). A value of True
            indicates that the element should take part in attention.
        num_heads: The number of attention heads
        dropout_p: The dropout probability
        rotate_value: Whether to rotate the value in addition to the query and key

    Returns:
        The output tensor, with shape (b, n_q, (h d))
    """
    # xformers attention expects shape (b, n, h, d)
    query = rearrange(query, "b n (h d) -> b n h d", h=num_heads)
    key = rearrange(key, "b n (h d) -> b n h d", h=num_heads)
    value = rearrange(value, "b n (h d) -> b n h d", h=num_heads)

    # apply rotary embeddings
    query = RotaryTimeEmbedding.rotate(x=query, rotary_emb=q_pos_emb, unsqueeze_dim=2)
    key = RotaryTimeEmbedding.rotate(x=key, rotary_emb=kv_pos_emb, unsqueeze_dim=2)
    if rotate_value:
        value = RotaryTimeEmbedding.rotate(
            x=value, rotary_emb=kv_pos_emb, unsqueeze_dim=2
        )

    # WARNING: this is very slow, avoid using attn_mask if possible, refer to
    # xformers documentation
    attn_bias = None
    if attn_mask is not None:
        # expand the per-key mask to one entry per (head, query, key)
        keep = repeat(attn_mask, "b m -> b h n m", h=num_heads, n=query.size(1))
        # additive bias: 0 where attention is allowed, -inf where it is masked
        # out (softmax is shift-invariant, so a constant non-zero bias on the
        # kept positions would be equivalent but needlessly perturbs the logits)
        attn_bias = torch.zeros_like(keep, dtype=query.dtype).masked_fill(
            keep.logical_not(), float("-inf")
        )

    out = xops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=dropout_p,
    )

    # when values were rotated, rotate the output with the inverted query
    # embedding
    if rotate_value:
        out = RotaryTimeEmbedding.rotate(
            x=out,
            rotary_emb=RotaryTimeEmbedding.invert(q_pos_emb),
            unsqueeze_dim=2,
        )

    out = rearrange(out, "b n h d -> b n (h d)")
    return out


def rotary_attn_xformers_varlen_func(
    *,
    query,
    key,
    value,
    q_pos_emb,
    kv_pos_emb,
    q_seqlen,
    kv_seqlen,
    num_heads: int,
    dropout_p: float,
    rotate_value: bool,
):
    r"""Wraps the memory efficient attention implementation with rotary embedding
    application, for sequences chained along a single dimension.

    Args:
        query: The query tensor, with shape (n_q, (h d))
        key: The key tensor, with shape (n_kv, (h d))
        value: The value tensor, with shape (n_kv, (h d))
        q_pos_emb: The query rotary position embedding, with shape (n_q, d)
        kv_pos_emb: The key rotary position embedding, with shape (n_kv, d)
        q_seqlen: The lengths of the chained query sequences, as a tensor or a
            list
        kv_seqlen: The lengths of the chained key/value sequences, as a tensor,
            a list, or None (self-attention callers pass None because q, k and v
            share the same lengths)
        num_heads: The number of attention heads
        dropout_p: The dropout probability
        rotate_value: Whether to rotate the value in addition to the query and key

    Returns:
        The output tensor, with shape (n_q, (h d))
    """
    # xformers attention expects shape (1, n, h, d)
    query = rearrange(query, "n (h d) -> () n h d", h=num_heads)
    key = rearrange(key, "n (h d) -> () n h d", h=num_heads)
    value = rearrange(value, "n (h d) -> () n h d", h=num_heads)

    # TODO check rotation works
    # NOTE(review): unlike the batched wrappers, unsqueeze_dim is not passed
    # here, so this relies on the default of RotaryTimeEmbedding.rotate - confirm
    query = RotaryTimeEmbedding.rotate(x=query, rotary_emb=q_pos_emb.unsqueeze(0))
    key = RotaryTimeEmbedding.rotate(x=key, rotary_emb=kv_pos_emb.unsqueeze(0))
    if rotate_value:
        value = RotaryTimeEmbedding.rotate(x=value, rotary_emb=kv_pos_emb.unsqueeze(0))

    # from_seqlens is given plain Python lists rather than tensors
    if isinstance(q_seqlen, torch.Tensor):
        q_seqlen = q_seqlen.tolist()
    if isinstance(kv_seqlen, torch.Tensor):
        kv_seqlen = kv_seqlen.tolist()

    # fill attention_bias with BlockDiagonalMask
    with torch.no_grad():
        attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
        )

    out = xops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=dropout_p,
    )

    # when values were rotated, rotate the output with the inverted query
    # embedding
    if rotate_value:
        out = RotaryTimeEmbedding.rotate(
            x=out,
            rotary_emb=RotaryTimeEmbedding.invert(q_pos_emb).unsqueeze(0),
        )

    out = rearrange(out, "() n h d -> n (h d)")
    return out