Source code for torch_brain.utils.tokenizers
import numpy as np
from einops import repeat
from enum import Enum
class TokenType(Enum):
DEFAULT = 0
START_OF_SEQUENCE = 1
END_OF_SEQUENCE = 2
[docs]
def create_start_end_unit_tokens(unit_ids, start, end):
r"""Creates for each unit a start and end token. Each token is defined by the
unit index, the token type index and the timestamps.
Args:
unit_ids (np.ndarray): List of unit identifiers.
start (float): The start time of the sequence.
end (float): The end time of the sequence.
"""
token_type_index = np.array(
[TokenType.START_OF_SEQUENCE.value, TokenType.END_OF_SEQUENCE.value],
dtype=np.int64,
)
token_type_index = repeat(token_type_index, "u -> (t u)", t=len(unit_ids))
unit_index = np.arange(len(unit_ids))
unit_index = repeat(unit_index, "u -> (u t)", t=2)
timestamps = np.array([start, end], dtype=np.float64)
timestamps = repeat(timestamps, "u -> (t u)", t=len(unit_ids))
return token_type_index, unit_index, timestamps
[docs]
def create_linspace_latent_tokens(start, end, step, num_latents_per_step):
r"""Creates a sequence of latent tokens. Each token is defined by the
latent index and the timestamps. The sequence is defined by the start and end
time and the step size. The group of `num_latents_per_step` latents is repeated
for each step.
Args:
start (float): The start time of the sequence.
end (float): The end time of the sequence.
step (float): The step size.
num_latents_per_step (int): The number of latents per step.
"""
sequence_len = end - start
latent_timestamps = np.arange(0, sequence_len, step) + step / 2 + start
latent_index = np.arange(num_latents_per_step, dtype=np.int64)
num_timestamps = len(latent_timestamps)
latent_timestamps = repeat(latent_timestamps, "t -> (t u)", u=len(latent_index))
latent_index = repeat(latent_index, "u -> (t u)", t=num_timestamps)
return latent_index, latent_timestamps