"""This file contains the definition of utility functions to group tokens."""
import math
import torch
def combine_factorized_tokens(
    tokens: torch.Tensor, codebook_size: int, splits: int
) -> torch.Tensor:
    """
    Combine factorized sub-tokens back into single tokens.

    Args:
        tokens -> torch.Tensor: Tensor of shape (batch_size, n, splits).
        codebook_size -> int: The size of the codebook.
        splits -> int: Number of splits.

    Returns:
        combined_tokens -> torch.Tensor: Tensor of shape (batch_size, n).
    """
    # Allocate with the input dtype so token indices stay integral
    # (torch.zeros would otherwise default to float32).
    combined_tokens = torch.zeros(
        (tokens.shape[0], tokens.shape[1]), dtype=tokens.dtype, device=tokens.device
    )
    # Each sub-token contributes bit_shift bits of the combined token.
    bit_shift = int(math.log2(codebook_size)) // splits
    for i in range(splits):
        combined_tokens += tokens[..., i] << (i * bit_shift)
    return combined_tokens
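

# Worked example: with codebook_size=1024 and splits=2, bit_shift is
# 10 // 2 = 5, so each sub-token holds 5 bits. Sub-tokens (low=11, high=20)
# combine to 11 + (20 << 5) = 651, and splitting 651 recovers (11, 20).
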
def split_factorized_tokens(
    tokens: torch.Tensor, codebook_size: int, splits: int
) -> torch.Tensor:
    """
    Split each token into multiple smaller sub-tokens.

    Args:
        tokens -> torch.Tensor: Tensor of shape (batch_size, n).
        codebook_size -> int: The size of the codebook.
        splits -> int: Number of splits.

    Returns:
        split_tokens -> torch.Tensor: Tensor of shape (batch_size, n, splits).
    """
    bit_shift = int(math.log2(codebook_size)) // splits
    bit_mask = (1 << bit_shift) - 1
    split_tokens = []
    for i in range(splits):
        # Select the i-th group of bit_shift bits and shift it down to the
        # least-significant position.
        split_tokens.append(
            (tokens & (bit_mask << (i * bit_shift))) >> (i * bit_shift)
        )
    return torch.stack(split_tokens, dim=2)
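

# Note: both helpers assume codebook_size is a power of two whose bit width
# is divisible by splits; with other values the highest bits are silently
# dropped in the split/combine round trip.
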
if __name__ == "__main__":
tokens = torch.randint(0, 1023, (1, 16))
split_tokens = split_factorized_tokens(tokens, 1024, 1)
assert split_tokens.shape == (1, 16, 1)
assert split_tokens.dtype == torch.int64
combined_tokens = combine_factorized_tokens(split_tokens, 1024, 1)
assert (tokens == combined_tokens).all()
split_tokens = split_factorized_tokens(tokens, 1024, 2)
combined_tokens = combine_factorized_tokens(split_tokens, 1024, 2)
assert split_tokens.shape == (1, 16, 2)
assert (tokens == combined_tokens).all(), f"{tokens} != {combined_tokens}"
assert (torch.bitwise_right_shift(tokens, 5) == split_tokens[..., 1]).all()
assert (tokens & 31 == split_tokens[..., 0]).all()
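
    # Additional sanity check (illustrative): five splits of a 10-bit
    # codebook give 2-bit sub-tokens, and the round trip should still
    # reconstruct the original tokens.
    split_tokens = split_factorized_tokens(tokens, 1024, 5)
    assert split_tokens.shape == (1, 16, 5)
    combined_tokens = combine_factorized_tokens(split_tokens, 1024, 5)
    assert (tokens == combined_tokens).all()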