File size: 5,901 Bytes
616f571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import sys
from typing import Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from cube3d.model.transformers.norm import RMSNorm


class SphericalVectorQuantizer(nn.Module):
    """Vector quantizer that matches RMS-normalized ("spherical") latents
    against an RMS-normalized codebook via nearest-neighbor search, with a
    straight-through estimator for gradients."""

    def __init__(
        self,
        embed_dim: int,
        num_codes: int,
        width: Optional[int] = None,
        codebook_regularization: Literal["batch_norm", "kl"] = "batch_norm",
    ):
        """
        Initializes the SphericalVQ module.
        Args:
            embed_dim (int): The dimensionality of the embeddings.
            num_codes (int): The number of codes in the codebook.
            width (Optional[int], optional): The width of the input. When it
                differs from ``embed_dim``, linear projections are added on the
                way in and out of the quantizer. Defaults to None (= embed_dim).
            codebook_regularization (Literal["batch_norm", "kl"], optional):
                How codebook weights are conditioned before RMS normalization:
                "batch_norm" uses a stat-free BatchNorm1d, anything else uses a
                learned per-channel affine (scale + bias). Defaults to
                "batch_norm".
        """
        super().__init__()

        self.num_codes = num_codes

        self.codebook = nn.Embedding(num_codes, embed_dim)
        # Small uniform init keeps initial codes near the origin before the
        # normalization layers spread them over the sphere.
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)

        width = width or embed_dim
        if width != embed_dim:
            self.c_in = nn.Linear(width, embed_dim)
            self.c_x = nn.Linear(width, embed_dim)  # shortcut projection
            self.c_out = nn.Linear(embed_dim, width)
        else:
            # Same dimensionality: all projections collapse to identity.
            self.c_in = self.c_out = self.c_x = nn.Identity()

        self.norm = RMSNorm(embed_dim, elementwise_affine=False)
        self.cb_reg = codebook_regularization
        if self.cb_reg == "batch_norm":
            # track_running_stats=False: always normalize with batch statistics,
            # never accumulated running mean/var.
            self.cb_norm = nn.BatchNorm1d(embed_dim, track_running_stats=False)
        else:
            self.cb_weight = nn.Parameter(torch.ones([embed_dim]))
            self.cb_bias = nn.Parameter(torch.zeros([embed_dim]))
            # x.mul(...) allocates a fresh tensor, so the in-place add_ is safe
            # (it never mutates the caller's tensor).
            self.cb_norm = lambda x: x.mul(self.cb_weight).add_(self.cb_bias)

    def get_codebook(self):
        """
        Retrieves the normalized codebook weights.
        Applies the regularization transform (``cb_norm``) followed by RMS
        normalization so codes are properly scaled before use.
        Returns:
            torch.Tensor: The normalized weights of the codebook,
                shape (num_codes, embed_dim).
        """

        return self.norm(self.cb_norm(self.codebook.weight))

    @torch.no_grad()
    def lookup_codebook(self, q: torch.Tensor):
        """
        Look up code indices in the normalized codebook and project the result
        back to the output width.
        Args:
            q (torch.Tensor): A tensor containing indices to look up in the codebook.
        Returns:
            torch.Tensor: The transformed embeddings retrieved from the codebook.
        """

        # normalize codebook, then gather rows by index
        z_q = F.embedding(q, self.get_codebook())
        z_q = self.c_out(z_q)
        return z_q

    @torch.no_grad()
    def lookup_codebook_latents(self, q: torch.Tensor):
        """
        Retrieves the latent representations from the codebook corresponding to
        the given indices, without the output projection (compare
        ``lookup_codebook``).
        Args:
            q (torch.Tensor): A tensor containing the indices of the codebook entries to retrieve.
                              The indices should be integers and correspond to the rows in the codebook.
        Returns:
            torch.Tensor: A tensor containing the latent representations retrieved from the codebook.
                          The shape of the returned tensor depends on the shape of the input indices
                          and the dimensionality of the codebook entries.
        """

        # normalize codebook, then gather rows by index
        z_q = F.embedding(q, self.get_codebook())
        return z_q

    def quantize(self, z: torch.Tensor):
        """
        Quantizes the latent codes z with the codebook.

        Args:
            z (Tensor): B x ... x F
        Returns:
            Tuple[torch.Tensor, dict]: the quantized tensor (same leading shape
            as ``z``) and a dict with the detached input ("z") and the chosen
            code indices ("q").
        """

        # normalize codebook
        codebook = self.get_codebook()
        # the process of finding quantized codes is non differentiable
        with torch.no_grad():
            # flatten z to (num_elements, F) for the pairwise distance
            z_flat = z.view(-1, z.shape[-1])

            # calculate distance and find the closest code
            d = torch.cdist(z_flat, codebook)
            q = torch.argmin(d, dim=1)  # num_ele

        z_q = codebook[q, :].reshape(*z.shape[:-1], -1)
        q = q.view(*z.shape[:-1])

        return z_q, {"z": z.detach(), "q": q}

    def straight_through_approximation(self, z, z_q):
        """Pass gradient from z_q to z (straight-through estimator): the
        forward value is z_q but the backward gradient flows through z."""
        z_q = z + (z_q - z).detach()
        return z_q

    def forward(self, z: torch.Tensor):
        """
        Forward pass of the spherical vector quantization autoencoder.
        Args:
            z (torch.Tensor): Input tensor of shape (batch_size, ..., feature_dim).
        Returns:
            Tuple[torch.Tensor, Dict[str, Any]]:
                - z_q (torch.Tensor): The quantized output tensor after applying the
                  straight-through approximation and output projection.
                - ret_dict (Dict[str, Any]): A dictionary containing additional
                  information:
                    - "z" (torch.Tensor): Detached, normalized encoder latents.
                    - "z_q" (torch.Tensor): Detached quantized tensor (before the
                      output projection).
                    - "q" (torch.Tensor): Indices of the quantized vectors.
        """

        with torch.autocast(device_type=z.device.type, enabled=False):
            # work in full precision: nearest-neighbor search is sensitive to
            # low-precision distance computation
            z = z.float()

            # project and normalize onto the sphere
            z_e = self.norm(self.c_in(z))
            z_q, ret_dict = self.quantize(z_e)

            ret_dict["z_q"] = z_q.detach()
            z_q = self.straight_through_approximation(z_e, z_q)
            z_q = self.c_out(z_q)

        return z_q, ret_dict