"""This file contains the definition of utility functions to group tokens."""

import math
import torch
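
# Note: these helpers pack and unpack token ids bitwise. A token drawn from a
# codebook of size 2**k is treated as `splits` groups of k // splits bits, with
# split 0 occupying the lowest bits. The round trip split -> combine is only
# lossless when codebook_size is a power of two and log2(codebook_size) is
# divisible by splits; otherwise the top bits of a token are dropped.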


def combine_factorized_tokens(
    tokens: torch.Tensor, codebook_size: int, splits: int
) -> torch.Tensor:
    """
    Combine factorized sub-tokens back into single tokens.

    Args:
        tokens -> torch.Tensor: Tensor of shape (batch_size, n, splits).
        codebook_size -> int: The size of the full codebook.
        splits -> int: Number of splits the tokens were factorized into.

    Returns:
        combined_tokens -> torch.Tensor: Tensor of shape (batch_size, n).
    """
    # Keep the integer dtype of the input; torch.zeros defaults to float32,
    # which would silently turn the returned token ids into floats.
    combined_tokens = torch.zeros(
        (tokens.shape[0], tokens.shape[1]), dtype=tokens.dtype, device=tokens.device
    )
    # Each split carries log2(codebook_size) // splits bits of the original token.
    bit_shift = int(math.log2(codebook_size)) // splits
    for i in range(splits):
        # Shift split i into its bit position and accumulate.
        combined_tokens += tokens[..., i] << (i * bit_shift)

    return combined_tokens


def split_factorized_tokens(
    tokens: torch.Tensor, codebook_size: int, splits: int
) -> torch.Tensor:
    """
    Split each token into multiple sub-tokens.

    Args:
        tokens -> torch.Tensor: Tensor of shape (batch_size, n).
        codebook_size -> int: The size of the full codebook.
        splits -> int: Number of splits to factorize the tokens into.

    Returns:
        split_tokens -> torch.Tensor: Tensor of shape (batch_size, n, splits).
    """
    # Each split covers log2(codebook_size) // splits bits of the token.
    bit_shift = int(math.log2(codebook_size)) // splits
    bit_mask = (1 << bit_shift) - 1

    split_tokens = []
    for i in range(splits):
        # Mask out the bits belonging to split i and shift them down to [0, 2**bit_shift).
        split_tokens.append((tokens & (bit_mask << (i * bit_shift))) >> (i * bit_shift))

    return torch.stack(split_tokens, dim=2)


if __name__ == "__main__":
    # randint's upper bound is exclusive, so use 1024 to cover the full codebook.
    tokens = torch.randint(0, 1024, (1, 16))

    # With a single split, splitting only adds a trailing dimension of size 1.
    split_tokens = split_factorized_tokens(tokens, 1024, 1)

    assert split_tokens.shape == (1, 16, 1)
    assert split_tokens.dtype == torch.int64

    combined_tokens = combine_factorized_tokens(split_tokens, 1024, 1)

    assert (tokens == combined_tokens).all()

    # With two splits, each 10-bit token is cut into two 5-bit halves.
    split_tokens = split_factorized_tokens(tokens, 1024, 2)
    combined_tokens = combine_factorized_tokens(split_tokens, 1024, 2)

    assert split_tokens.shape == (1, 16, 2)
    assert (tokens == combined_tokens).all(), f"{tokens} != {combined_tokens}"

    # The high split holds bits 5-9, the low split bits 0-4.
    assert (torch.bitwise_right_shift(tokens, 5) == split_tokens[..., 1]).all()
    assert ((tokens & 31) == split_tokens[..., 0]).all()
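
    # A small worked example (values chosen purely for illustration): with
    # codebook_size=1024 and splits=2, token 693 = 0b1010110101 splits into the
    # low 5-bit half 0b10101 = 21 and the high 5-bit half 0b10101 = 21.
    example = torch.tensor([[693]])
    example_split = split_factorized_tokens(example, 1024, 2)
    assert (example_split == torch.tensor([[[21, 21]]])).all()
    assert (combine_factorized_tokens(example_split, 1024, 2) == example).all()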