# Source code for torch_brain.nn.position_embeddings

from typing import Union
import torch
from torch import nn, Tensor
from einops import repeat, rearrange


[docs] class SinusoidalTimeEmbedding(nn.Module): r"""Sinusoidal time/position embedding layer. These embeddings are generally added/concatenated to tokens to give them a sense of time/position. The timeperiods are logarithmically spaced between ``t_min`` and ``t_max`` (both inclusive). Args: dim (int): The dimension of the embedding needed (must be a multiple of 2) t_min (float): Minimum period of the sinusoids. Set this to the smallest timescale you care about. t_max (float): Maximum period of the sinusoids. Set this to the largest timescale you care about. """ omega: Tensor def __init__(self, dim: int, t_min: float, t_max: float): super().__init__() if dim % 2 != 0: raise ValueError("`dim` must be a multiple of 2") periods = generate_logspace_timeperiods(dim // 2, t_min, t_max) omega = 2 * torch.pi / periods self.register_buffer("omega", omega)
[docs] @torch.no_grad @torch.autocast(device_type="cuda", enabled=False) def forward(self, timestamps: Tensor) -> Tensor: r"""Convert raw timestamps to sinusoidal embeddings Args: timestamps (torch.Tensor): timestamps tensor """ angles = timestamps[..., None] * self.omega return torch.cat((angles.sin(), angles.cos()), dim=-1)
[docs] class RotaryTimeEmbedding(nn.Module): r"""Rotary time/positional embedding layer. This module is designed to be used with :class:`torch_brain.nn.RotarySelfAttention` and :class:`torch_brain.nn.RotaryCrossAttention` to modulate the attention weights in accordance with relative timing/positions of the tokens. Original paper: `RoFormer: Enhanced Transformer with Rotary Position Embedding <https://arxiv.org/abs/2104.09864>`_ The timeperiods are computed using :func:`generate_logspace_timeperiods`. Args: head_dim (int): Dimension of the attention head. rotate_dim (int): Number of dimensions to rotate. You can choose to rotate only a small portion of the head dimension using this parameter. E.g. `PerceiverIO <https://arxiv.org/abs/2107.14795>`_ found rotating only half dimensions to be effective. t_min (float): Minimum period of the sinusoids. Set this to the smallest timescale the attention layer should care about. t_max (float): Maximum period of the sinusoids. Set this to the largest timescale the attention layer should care about. """ omega: Tensor def __init__(self, head_dim: int, rotate_dim: int, t_min: float, t_max: float): super().__init__() if rotate_dim % 2 != 0: raise ValueError("rotate_dim must be a multiple of 2") if not head_dim >= rotate_dim: raise ValueError("head_dim must be equal to or larger than rotate_dim") periods = generate_logspace_timeperiods(rotate_dim // 2, t_min, t_max) omega = torch.zeros(head_dim // 2) omega[: rotate_dim // 2] = 2 * torch.pi / periods self.register_buffer("omega", omega)
[docs] @torch.no_grad @torch.autocast(device_type="cuda", enabled=False) def forward(self, timestamps: Tensor) -> Tensor: r"""Computes the rotary embeddings for given timestamps, which can then be used by :meth:`RotaryTimeEmbedding.rotate`. Args: timestamps (torch.Tensor): timestamps tensor. """ angles = torch.einsum("..., f -> ... f", timestamps, self.omega) angles = repeat(angles, "... n -> ... (n r)", r=2) rotary_emb = torch.cat((angles.cos(), angles.sin()), dim=-1) return rotary_emb
@staticmethod def _rotate_half(x: Tensor) -> Tensor: x = rearrange(x, "... (d r) -> ... d r", r=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) return rearrange(x, "... d r -> ... (d r)")
[docs] @staticmethod def rotate( x: Tensor, rotary_emb: Tensor, unsqueeze_dim: int = 2, ) -> Tensor: r"""Apply the rotary positional embedding to the input data. Args: x (torch.Tensor): Input data. rotary_emb (torch.Tensor): The rotary embedding produced by a forward call of :class:`RotaryTimeEmbedding`. unsqueeze_dim (int, optional): Dimension where heads are located in the input tensor. E.g. For input shape (batch, heads, seq_len, dim) use 1. For input shape (batch, seq_len, heads, dim) use 2. Defaults to 2. """ rotary_emb = rotary_emb.unsqueeze(unsqueeze_dim).to(x.dtype) cos, sin = rotary_emb.chunk(chunks=2, dim=-1) return (x * cos) + (RotaryTimeEmbedding._rotate_half(x) * sin)
[docs] @staticmethod def invert(rotary_emb: Tensor) -> Tensor: r"""Invert/Negate rotary embedding. If the input embeddings correspond to a time :math:`t`, then the output embeddings correspond to time :math:`-t`. Args: rotary_emb (torch.Tensor): Embeddings produced by a forward call of :class:`RotaryTimeEmbedding`. """ cos, sin = rotary_emb.chunk(chunks=2, dim=-1) return torch.cat((cos, -sin), dim=-1)
def generate_logspace_timeperiods(
    num: int,
    t_min: Union[float, Tensor],
    t_max: Union[float, Tensor],
) -> Tensor:
    r"""Generates ``num`` timeperiods that are logarithmically spaced between
    ``t_min`` and ``t_max`` (both inclusive).

    Args:
        num (int): number of timeperiods needed. Must be ``0`` (empty result) or
            at least ``2``, since a single period cannot include both endpoints.
        t_min (float or scalar Tensor): smallest timeperiod
        t_max (float or scalar Tensor): largest timeperiod

    Returns:
        torch.Tensor: float32 tensor of shape ``(num,)``.

    Raises:
        ValueError: if ``0 < t_min < t_max`` does not hold, or if ``num`` is
            negative or equal to 1.
    """
    if not 0 < t_min < t_max:
        raise ValueError(
            f"Invalid t_min ({t_min}) and t_max ({t_max}). They should follow 0 < t_min < t_max."
        )
    if num < 0 or num == 1:
        raise ValueError(
            f"Invalid num ({num}). It should be 0 or at least 2 so that both "
            "t_min and t_max are included."
        )

    # `as_tensor` accepts tensors without the copy-construct warning that
    # `torch.tensor` emits. Forcing float32 keeps int / float64 inputs consistent
    # with `exponents` in `lerp` and with the `isclose` checks below, which
    # require both operands to share a dtype.
    t_min = torch.as_tensor(t_min, dtype=torch.float32)
    t_max = torch.as_tensor(t_max, dtype=torch.float32)
    exponents = torch.linspace(0, 1.0, num, dtype=torch.float32)
    # Interpolate linearly in log space, then map back: geometric spacing.
    periods = torch.exp(torch.lerp(t_min.log(), t_max.log(), exponents))

    if num > 0:
        # Internal sanity check: endpoints survive the log/exp round trip.
        assert torch.isclose(periods[0], t_min)
        assert torch.isclose(periods[-1], t_max)
    return periods