from typing import Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from cube3d.model.transformers.norm import RMSNorm


class SphericalVectorQuantizer(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_codes: int,
        width: Optional[int] = None,
        codebook_regularization: Literal["batch_norm", "kl"] = "batch_norm",
    ):
        """
        Initializes the SphericalVQ module.

        Args:
            embed_dim (int): The dimensionality of the embeddings.
            num_codes (int): The number of codes in the codebook.
            width (Optional[int], optional): The width of the input. Defaults to None,
                in which case it is set to embed_dim and no projection layers are used.
            codebook_regularization (Literal["batch_norm", "kl"], optional): How the raw
                codebook weights are regularized before spherical normalization.
                Defaults to "batch_norm".
        """
        super().__init__()

        self.num_codes = num_codes

        self.codebook = nn.Embedding(num_codes, embed_dim)
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)

        width = width or embed_dim
        if width != embed_dim:
            self.c_in = nn.Linear(width, embed_dim)
            self.c_x = nn.Linear(width, embed_dim)  # shortcut
            self.c_out = nn.Linear(embed_dim, width)
        else:
            self.c_in = self.c_out = self.c_x = nn.Identity()
        self.norm = RMSNorm(embed_dim, elementwise_affine=False)

        self.cb_reg = codebook_regularization
        if self.cb_reg == "batch_norm":
            self.cb_norm = nn.BatchNorm1d(embed_dim, track_running_stats=False)
        else:
            self.cb_weight = nn.Parameter(torch.ones([embed_dim]))
            self.cb_bias = nn.Parameter(torch.zeros([embed_dim]))
            self.cb_norm = lambda x: x.mul(self.cb_weight).add_(self.cb_bias)

    def get_codebook(self):
        """
        Retrieves the normalized codebook weights.

        The raw codebook weights are first regularized (batch norm or a learned
        affine transform) and then RMS-normalized, so lookups and distance
        computations operate on a spherical codebook.

        Returns:
            torch.Tensor: The normalized weights of the codebook.
        """
        return self.norm(self.cb_norm(self.codebook.weight))

    @torch.no_grad()
    def lookup_codebook(self, q: torch.Tensor):
        """
        Perform a lookup in the codebook and project the result back to the input width.

        Args:
            q (torch.Tensor): A tensor containing indices to look up in the codebook.

        Returns:
            torch.Tensor: The projected embeddings retrieved from the codebook.
        """
        # normalize codebook
        z_q = F.embedding(q, self.get_codebook())
        z_q = self.c_out(z_q)
        return z_q

    @torch.no_grad()
    def lookup_codebook_latents(self, q: torch.Tensor):
        """
        Retrieves the latent representations from the codebook corresponding to
        the given indices.

        Args:
            q (torch.Tensor): A tensor containing the indices of the codebook entries
                to retrieve. The indices should be integers and correspond to rows
                in the codebook.

        Returns:
            torch.Tensor: A tensor containing the latent representations retrieved
                from the codebook. Its shape follows the shape of the input indices
                plus the dimensionality of the codebook entries.
        """
        # normalize codebook
        z_q = F.embedding(q, self.get_codebook())
        return z_q

    def quantize(self, z: torch.Tensor):
        """
        Quantizes the latent codes z with the codebook.

        Args:
            z (Tensor): B x ... x F
        """
        # normalize codebook
        codebook = self.get_codebook()

        # the process of finding quantized codes is non differentiable
        with torch.no_grad():
            # flatten z
            z_flat = z.view(-1, z.shape[-1])
            # calculate distance and find the closest code
            d = torch.cdist(z_flat, codebook)
            q = torch.argmin(d, dim=1)  # num_ele

        z_q = codebook[q, :].reshape(*z.shape[:-1], -1)
        q = q.view(*z.shape[:-1])

        return z_q, {"z": z.detach(), "q": q}

    def straight_through_approximation(self, z, z_q):
        """Passes the gradient from z_q through to z (straight-through estimator)."""
        z_q = z + (z_q - z).detach()
        return z_q

    def forward(self, z: torch.Tensor):
        """
        Forward pass of the spherical vector quantizer.

        Args:
            z (torch.Tensor): Input tensor of shape (batch_size, ..., feature_dim).

        Returns:
            Tuple[torch.Tensor, Dict[str, Any]]:
                - z_q (torch.Tensor): The quantized output tensor after applying the
                  straight-through approximation and the output projection.
                - ret_dict (Dict[str, Any]): A dictionary containing additional information:
                    - "z" (torch.Tensor): Detached, normalized encoder latents.
                    - "q" (torch.Tensor): Indices of the selected codebook entries.
                    - "z_q" (torch.Tensor): Detached quantized latents before the
                      straight-through approximation and output projection.
        """
        with torch.autocast(device_type=z.device.type, enabled=False):
            # work in full precision
            z = z.float()
            # project and normalize
            z_e = self.norm(self.c_in(z))
            z_q, ret_dict = self.quantize(z_e)
            ret_dict["z_q"] = z_q.detach()
            z_q = self.straight_through_approximation(z_e, z_q)
            z_q = self.c_out(z_q)
        return z_q, ret_dict