import torch
import torch.nn as nn
from einops import rearrange


class FactorizedEmbedding(nn.Module):
    """
    Each token's embedding is the sum of the embeddings in each factorized vocabulary.
    Equivalent to nn.Embedding when `num_factored_vocabs` = 1.
    """
    def __init__(self, factored_vocab_size: int, num_factored_vocabs: int, d_model: int, mask_token_id: int):
        """

        Args:
            config: Should specify `factored_vocab_size`, `d_model`, `num_factored_vocabs`, `image_vocab_size`.
                E.g. genie.config.GenieConfig
        """
        super().__init__()

        self.factored_vocab_size = factored_vocab_size
        self.num_factored_vocabs = num_factored_vocabs
        self.d_model = d_model
        self.mask_token_id = mask_token_id

        # One embedding table per factored vocabulary; their outputs are summed in `forward`.
        self.factored_embeds = nn.ModuleList([nn.Embedding(factored_vocab_size, d_model)
                                              for _ in range(num_factored_vocabs)])
        self.mask_token_embed = nn.Parameter(torch.zeros(1, d_model))

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """

        Args:
            input_ids: Shape (B, T, H*W)

        Returns:
            input embeddings: Shape (B, T, H*W, d_model)
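
        Example (illustrative; shapes only, with hypothetical sizes and mask_token_id):
            >>> embed = FactorizedEmbedding(512, 2, 64, mask_token_id=2 ** 18)
            >>> embed(torch.zeros(1, 2, 16, dtype=torch.long)).shape
            torch.Size([1, 2, 16, 64])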
        """
        # initialize all embeddings to the mask token embedding, and then fill in actual token embeddings
        embeds = self.mask_token_embed.repeat(input_ids.size() + (1,))
        is_not_mask = input_ids != self.mask_token_id

        factored_token_ids = factorize_token_ids(
            input_ids[is_not_mask], self.num_factored_vocabs, self.factored_vocab_size
        )

        # look up the ids from each factored vocabulary in its own embedding table
        unmasked_embeds = [
            factored_embed(factored_ids)
            for factored_embed, factored_ids in zip(self.factored_embeds, factored_token_ids.unbind(-1))
        ]

        embeds[is_not_mask] = torch.sum(torch.stack(unmasked_embeds), dim=0)
        return embeds


def factorize_token_ids(
    token_ids: torch.LongTensor,
    num_factored_vocabs: int = 2,
    factored_vocab_size: int = 512
) -> torch.LongTensor:
    """
    `token_ids`: any size tensor with token id values in [0, image_vocab_size = 2**18).

    Returns:
        Size token_ids.size() + (num_factored_vocabs,), where the last dimension has token ids in
        each individual vocabulary, with values in [0, factored_vocab_size = 512)
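
    Example (illustrative, using the defaults num_factored_vocabs=2, factored_vocab_size=512):
        >>> factorize_token_ids(torch.tensor([7 + 3 * 512]))
        tensor([[7, 3]])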
    """
    powers = factored_vocab_size ** torch.arange(num_factored_vocabs, device=token_ids.device)
    return (token_ids.unsqueeze(-1) // powers) % factored_vocab_size


def unfactorize_token_ids(
    factored_token_ids: torch.LongTensor,
    num_factored_vocabs: int = 2,
    factored_vocab_size: int = 512
) -> torch.LongTensor:
    """
    Inverse of `factorize_token_ids`.
    It is assumed that the last dimension of `factored_token_ids` is the vocabulary dimension.

    Returns:
        Size token_ids.size()[:-1], with values in [0, image_vocab_size = 2**18)
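
    Example (illustrative; inverse of the `factorize_token_ids` example):
        >>> unfactorize_token_ids(torch.tensor([[7, 3]]))
        tensor([1543])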
    """
    powers = factored_vocab_size ** torch.arange(num_factored_vocabs, device=factored_token_ids.device)
    return (factored_token_ids * powers).sum(dim=-1)


def factorize_labels(
    labels_THW: torch.LongTensor,
    num_factored_vocabs: int = 2,
    factored_vocab_size: int = 512
) -> torch.LongTensor:
    """
    Simply `factorize_token_ids` followed by permuting dimensions.
    labels_THW: shape (B, T, H, W), values in [0, image_vocab_size=2**18)

    Returns:
        factored_labels: shape (B, num_factored_vocabs=2, T, H, W), values in [0, factored_vocab_size=512)
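
    Example (illustrative; shapes only):
        >>> factorize_labels(torch.zeros(1, 4, 3, 3, dtype=torch.long)).shape
        torch.Size([1, 2, 4, 3, 3])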
    """
    factored_labels = factorize_token_ids(labels_THW, num_factored_vocabs, factored_vocab_size)
    return rearrange(factored_labels, "b t h w num_factored_vocabs -> b num_factored_vocabs t h w")


def nth_root(x: int, n: int) -> int:
    """Integer n-th root of `x`; asserts that `x` is a perfect n-th power."""
    root = round(x ** (1 / n))
    assert root ** n == x, (x, n, root)
    return root
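

# Minimal usage sketch (illustrative): exercises the round-trip factorization and the
# FactorizedEmbedding output shapes. It assumes the default 2**18 image vocabulary
# factored into two vocabularies of size 512, and a hypothetical mask_token_id of 2**18.
if __name__ == "__main__":
    # Factorize then unfactorize should recover the original token ids exactly.
    token_ids = torch.randint(0, 2 ** 18, (4, 16))
    assert torch.equal(unfactorize_token_ids(factorize_token_ids(token_ids)), token_ids)

    # FactorizedEmbedding maps (B, T, H*W) token ids to (B, T, H*W, d_model) embeddings.
    embedder = FactorizedEmbedding(
        factored_vocab_size=512, num_factored_vocabs=2, d_model=64, mask_token_id=2 ** 18
    )
    input_ids = torch.randint(0, 2 ** 18, (2, 3, 16))
    input_ids[0, 0, :4] = 2 ** 18  # mark a few positions as masked
    embeds = embedder(input_ids)
    assert embeds.shape == (2, 3, 16, 64)
    # Masked positions keep the learned mask embedding, which is zero-initialized here.
    assert torch.equal(embeds[0, 0, 0], embedder.mask_token_embed[0])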