"""This file contains the definition of utility functions to group tokens."""
import math
import torch
def combine_factorized_tokens(
    tokens: torch.Tensor, codebook_size: int, splits: int
) -> torch.Tensor:
    """Combine factorized tokens back into a single token per position.

    Args:
        tokens -> torch.Tensor: Tensor of shape (batch_size, n, splits).
        codebook_size -> int: The size of the full codebook.
        splits -> int: Number of splits. Must evenly divide log2(codebook_size).

    Returns:
        combined_tokens -> torch.Tensor: Tensor of shape (batch_size, n).
    """
    # Keep the integer dtype of the input so the result can be used as indices
    # (torch.zeros would otherwise default to float32).
    combined_tokens = torch.zeros(
        (tokens.shape[0], tokens.shape[1]), device=tokens.device, dtype=tokens.dtype
    )
    bit_shift = int(math.log2(codebook_size)) // splits
    for i in range(splits):
        # Each split carries bit_shift bits; shift split i back to its bit range.
        combined_tokens += tokens[..., i] << (i * bit_shift)
    return combined_tokens
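
# Worked example (illustrative comment, not part of the original module): with
# codebook_size=1024 and splits=2, bit_shift = log2(1024) // 2 = 5, so a
# factorized pair [low, high] recombines as low + (high << 5); e.g.
# [3, 7] -> 3 + (7 << 5) = 3 + 224 = 227.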
def split_factorized_tokens(
    tokens: torch.Tensor, codebook_size: int, splits: int
) -> torch.Tensor:
    """Split each token into `splits` factorized tokens.

    Args:
        tokens -> torch.Tensor: Tensor of shape (batch_size, n).
        codebook_size -> int: The size of the full codebook.
        splits -> int: Number of splits. Must evenly divide log2(codebook_size).

    Returns:
        split_tokens -> torch.Tensor: Tensor of shape (batch_size, n, splits).
    """
    bit_shift = int(math.log2(codebook_size)) // splits
    bit_mask = (1 << bit_shift) - 1
    split_tokens = []
    for i in range(splits):
        # Mask out the bit_shift bits belonging to split i, then move them down.
        split_tokens.append((tokens & (bit_mask << (i * bit_shift))) >> (i * bit_shift))
    return torch.stack(split_tokens, dim=2)
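
# Worked example (illustrative comment): splitting token 227 with
# codebook_size=1024 and splits=2 uses bit_mask = 0b11111 = 31, giving
# [227 & 31, (227 >> 5) & 31] = [3, 7], the exact inverse of
# combine_factorized_tokens above.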
if __name__ == "__main__":
    # Sample tokens over the full codebook range [0, 1024); randint's upper
    # bound is exclusive, so 1024 (not 1023) covers the top index.
    tokens = torch.randint(0, 1024, (1, 16))

    split_tokens = split_factorized_tokens(tokens, 1024, 1)
    assert split_tokens.shape == (1, 16, 1)
    assert split_tokens.dtype == torch.int64
    combined_tokens = combine_factorized_tokens(split_tokens, 1024, 1)
    assert (tokens == combined_tokens).all()

    split_tokens = split_factorized_tokens(tokens, 1024, 2)
    combined_tokens = combine_factorized_tokens(split_tokens, 1024, 2)
    assert split_tokens.shape == (1, 16, 2)
    assert (tokens == combined_tokens).all(), f"{tokens} != {combined_tokens}"
    # With codebook_size=1024 and splits=2, each split carries 5 bits:
    # split 1 holds the high bits, split 0 the low bits.
    assert (torch.bitwise_right_shift(tokens, 5) == split_tokens[..., 1]).all()
    assert (tokens & 31 == split_tokens[..., 0]).all()
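
    # Added sanity check (a sketch, not in the original file): splits=5 also
    # divides log2(1024) = 10 evenly, so with 2 bits per split the
    # split/combine round trip should be exact as well.
    split_tokens = split_factorized_tokens(tokens, 1024, 5)
    assert split_tokens.shape == (1, 16, 5)
    combined_tokens = combine_factorized_tokens(split_tokens, 1024, 5)
    assert (tokens == combined_tokens).all()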