"""This file contains the definition of utility functions to group tokens.""" import math import torch def combine_factorized_tokens( tokens: torch.Tensor, codebook_size: int, splits: int ) -> torch.Tensor: """ Combine the tokens into a single token. Args: tokens -> torch.Tensor: Tensor of shape (batch_size, n, m). codebook_size -> int: The size of the codebook. splits -> int: Number of splits. Returns: combined_tokens -> torch.Tensor: Tensor of shape (batch_size, n). """ combined_tokens = torch.zeros( (tokens.shape[0], tokens.shape[1]), device=tokens.device ) bit_shift = int(math.log2(codebook_size)) // splits for i in range(splits): combined_tokens += tokens[..., i] << (i * bit_shift) return combined_tokens def split_factorized_tokens( tokens: torch.Tensor, codebook_size: int, splits: int ) -> torch.Tensor: """ Split the tokens into multiple tokens. Args: tokens -> torch.Tensor: Tensor of shape (batch_size, n). codebook_size -> int: The size of the codebook. splits -> int: Number of splits. Returns: split_tokens -> torch.Tensor: Tensor of shape (batch_size, n, m). """ bit_shift = int(math.log2(codebook_size)) // splits bit_mask = (1 << bit_shift) - 1 split_tokens = [] for i in range(splits): split_tokens.append((tokens & (bit_mask << (i * bit_shift))) >> (i * bit_shift)) return torch.stack(split_tokens, dim=2) if __name__ == "__main__": tokens = torch.randint(0, 1023, (1, 16)) split_tokens = split_factorized_tokens(tokens, 1024, 1) assert split_tokens.shape == (1, 16, 1) assert split_tokens.dtype == torch.int64 combined_tokens = combine_factorized_tokens(split_tokens, 1024, 1) assert (tokens == combined_tokens).all() split_tokens = split_factorized_tokens(tokens, 1024, 2) combined_tokens = combine_factorized_tokens(split_tokens, 1024, 2) assert split_tokens.shape == (1, 16, 2) assert (tokens == combined_tokens).all(), f"{tokens} != {combined_tokens}" assert (torch.bitwise_right_shift(tokens, 5) == split_tokens[..., 1]).all() assert (tokens & 31 == split_tokens[..., 0]).all()