"""This file contains the definition of the our tokenizer, which can use VQ or LFQ.""" import math from typing import Mapping, Text, Tuple import torch from einops import rearrange from .modules import BaseModel, ConvDecoder, ConvDecoderLegacy, ConvEncoder from .quantizer import LookupFreeQuantizer, SimpleVectorizer def choose_vector_quantizer_class(config): if config.quantizer_type == "lookup": return SimpleVectorizer( config.codebook_size, config.token_size, config.commitment_cost, config.entropy_loss_weight, config.entropy_loss_temperature, config.entropy_gamma, config.get("use_l2_normalisation", False), ) elif config.quantizer_type == "lookup-free": return LookupFreeQuantizer( config.token_size, config.commitment_cost, config.entropy_loss_weight, config.entropy_loss_temperature, config.entropy_gamma, ) elif config.quantizer_type == "vae": return NotImplementedError( "Currently not supported. We welcome a well tested PR." ) else: raise ValueError("Unknown vector quantizer class") class ConvVQModel(BaseModel): def __init__(self, config, legacy: bool = False, finetune_decoder: bool = False): """Initializes the convolutional VQ-VAE model. Args: config: The configuration for the model. legacy -> bool: Whether to use the legacy decoder, which is a different implementation of the same architecture. finetune_decoder -> bool: Whether to finetune the decoder. """ super().__init__() self.config = config self.encoder = ConvEncoder(self.config) if legacy: # To support older weights and MaskGIT self.decoder = ConvDecoderLegacy(self.config) else: self.decoder = ConvDecoder(self.config) self.finetune_decoder = finetune_decoder if self.finetune_decoder: self.encoder.eval() self.encoder.requires_grad_(False) self.quantize = choose_vector_quantizer_class(self.config) def get_last_layer(self): return self.decoder.conv_out.weight def encode( self, x: torch.Tensor ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: """Encodes the input tensor, i.e. runs the encoder. Args: x -> torch.Tensor: The input tensor. Returns: z_quantized -> torch.Tensor: The quantized latent representation. result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results and losses from the quantizer. """ z = self.encoder(x) z_quantized, result_dict = self.quantize(z) return z_quantized, result_dict def decode(self, z_quantized: torch.Tensor) -> torch.Tensor: """Decodes the quantized latent representation, i.e. runs the decoder. Args: z_quantized -> torch.Tensor: The quantized latent representation. Returns: decoded -> torch.Tensor: The decoded image. """ decoded = self.decoder(z_quantized) return decoded def decode_tokens(self, tokens: torch.Tensor) -> torch.Tensor: """Decodes from tokens, i.e. runs the decoder after converting tokens to latent representations. Args: tokens -> torch.Tensor: The tokens. Returns: decoded -> torch.Tensor: The decoded image. """ z_quantized = self.quantize.get_codebook_entry(tokens) ss = int(math.sqrt(float(z_quantized.size(1)))) z_quantized = z_quantized.reshape(z_quantized.size(0), ss, ss, -1) z_quantized = rearrange(z_quantized, "b h w c -> b c h w").contiguous() decoded = self.decode(z_quantized) return decoded def forward( self, input: torch.Tensor ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: """Runs the model on the input tensor. Args: input -> torch.Tensor: The input image. Returns: decoded -> torch.Tensor: The decoded image. result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results and losses from the quantizer. """ if self.finetune_decoder: self.encoder.eval() z_quantized, result_dict = self._finetuning_encoder_forward(input) else: z_quantized, result_dict = self.encode(input) decoded = self.decode(z_quantized) return decoded, result_dict["min_encoding_indices"], z_quantized def _finetuning_encoder_forward( self, input: torch.Tensor ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: """Runs the encoder on the input tensor without gradients and sets quantizer losses to 0. Args: input -> torch.Tensor: The input image. Returns: z_quantized -> torch.Tensor: The quantized latent representation. result_dict -> Mapping[Text, torch.Tensor]: A dictionary containing additional results and losses from the quantizer. """ with torch.no_grad(): z_quantized, result_dict = self.encode(input) result_dict["quantizer_loss"] *= 0 result_dict["commitment_loss"] *= 0 if "codebook_loss" in result_dict: result_dict["codebook_loss"] *= 0 result_dict["entropy_loss"] *= 0 return z_quantized, result_dict