import random
from math import ceil
from functools import partial
from itertools import zip_longest
from random import randrange

import torch
from torch import nn
import torch.nn.functional as F

from models.vq.quantizer import QuantizeEMAReset, QuantizeEMA

from einops import rearrange, repeat, pack, unpack


# helper functions

def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None


def default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``."""
    return val if exists(val) else d


def round_up_multiple(num, mult):
    """Round ``num`` up to the nearest multiple of ``mult``."""
    return ceil(num / mult) * mult


# main class

class ResidualVQ(nn.Module):
    """Residual vector quantizer: a stack of ``QuantizeEMAReset`` layers.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    Each layer quantizes what the previous layers left unexplained (the
    residual); the per-layer quantized outputs are summed to form the final
    quantized tensor. During training, trailing layers can be randomly
    skipped ("quantize dropout") so that the first layers carry as much
    information as possible.
    """

    def __init__(
        self,
        num_quantizers,
        shared_codebook=False,
        quantize_dropout_prob=0.5,
        quantize_dropout_cutoff_index=0,
        **kwargs
    ):
        """Build the quantizer stack.

        Args:
            num_quantizers: number of residual quantization layers.
            shared_codebook: when True, a single ``QuantizeEMAReset``
                instance is reused for every layer.
            quantize_dropout_prob: probability, per training forward pass,
                of dropping trailing layers.
            quantize_dropout_cutoff_index: lower bound for the randomly drawn
                index of the last layer that is kept when dropout triggers.
            **kwargs: forwarded unchanged to ``QuantizeEMAReset``.
        """
        super().__init__()

        self.num_quantizers = num_quantizers

        if shared_codebook:
            # The same module object is registered num_quantizers times, so
            # all levels share one set of parameters/buffers.
            layer = QuantizeEMAReset(**kwargs)
            self.layers = nn.ModuleList([layer for _ in range(num_quantizers)])
        else:
            self.layers = nn.ModuleList(
                [QuantizeEMAReset(**kwargs) for _ in range(num_quantizers)]
            )

        assert quantize_dropout_cutoff_index >= 0 and quantize_dropout_prob >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_prob = quantize_dropout_prob

    @property
    def codebooks(self):
        """Stack every layer's ``codebook`` along a new leading axis ('q c d')."""
        codebooks = [layer.codebook for layer in self.layers]
        return torch.stack(codebooks, dim=0)  # 'q c d'

    def get_codes_from_indices(self, indices):
        """Look up the code vectors addressed by ``indices``.

        Args:
            indices: integer tensor laid out as 'b n q'. The last axis may be
                shorter than ``num_quantizers`` (it is right-padded with -1),
                and entries equal to -1 mark dropped-out quantizers.

        Returns:
            Tensor laid out as 'q b n d'; positions whose index was -1 are
            filled with zeros.
        """
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # Because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct.
        if quantize_dim < self.num_quantizers:
            indices = F.pad(
                indices, (0, self.num_quantizers - quantize_dim), value=-1
            )

        # Get ready for gathering.
        codebooks = repeat(self.codebooks, 'q c d -> q b c d', b=batch)
        gather_indices = repeat(
            indices, 'b n q -> q b n d', d=codebooks.shape[-1]
        )

        # Take care of quantizer dropout: -1 is not a valid gather index, so
        # fetch a dummy code (index 0) and zero it out afterwards.
        mask = gather_indices == -1
        gather_indices = gather_indices.masked_fill(mask, 0)

        all_codes = codebooks.gather(2, gather_indices)  # gather all codes

        # Mask out any codes that were dropout-ed.
        all_codes = all_codes.masked_fill(mask, 0.)

        return all_codes  # 'q b n d'

    def get_codebook_entry(self, indices):
        """Reconstruct the summed latent for ``indices`` ('b n q').

        Returns:
            The per-quantizer codes summed over the quantizer axis and
            permuted from 'b n d' to 'b d n'.
        """
        all_codes = self.get_codes_from_indices(indices)  # 'q b n d'
        latent = torch.sum(all_codes, dim=0)  # 'b n d'
        return latent.permute(0, 2, 1)

    def forward(
        self,
        x,
        return_all_codes=False,
        sample_codebook_temp=None,
        force_dropout_index=-1,
    ):
        """Quantize ``x`` with the residual stack.

        Args:
            x: tensor to quantize. NOTE(review): presumably laid out as
                (batch, dim, seq) -- the placeholder indices are built with
                shape ``(x.shape[0], x.shape[-1])``; confirm against the
                layer implementation. ``x`` is not modified.
            return_all_codes: when True, additionally return the codes of
                every layer as produced by ``get_codes_from_indices``.
            sample_codebook_temp: passed to every layer as ``temperature``.
            force_dropout_index: when >= 0, dropout is forced and only layers
                with index <= ``force_dropout_index`` are evaluated.

        Returns:
            ``(quantized_out, all_indices, all_losses, all_perplexity)`` and,
            if ``return_all_codes`` is set, a fifth element ``all_codes``.
            ``all_indices`` stacks the per-layer indices on the last axis
            (-1 for skipped layers); the loss and perplexity are averaged
            over the layers that were actually evaluated.
        """
        num_quant, device = self.num_quantizers, x.device

        should_quantize_dropout = (
            self.training and random.random() < self.quantize_dropout_prob
        )
        start_drop_quantize_index = num_quant

        # To ensure the first-k layers learn things as much as possible, we
        # randomly dropout the last q - k layers.
        if should_quantize_dropout:
            # Keep quant layers <= the drawn index. TODO: vary in batch.
            start_drop_quantize_index = randrange(
                self.quantize_dropout_cutoff_index, num_quant
            )

        # An explicit index overrides whatever was drawn above.
        if force_dropout_index >= 0:
            should_quantize_dropout = True
            start_drop_quantize_index = force_dropout_index

        null_indices = None
        if should_quantize_dropout:
            # Placeholder indices (-1) recorded for every skipped layer.
            null_indices = torch.full(
                (x.shape[0], x.shape[-1]), -1, device=device, dtype=torch.long
            )

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []
        all_perplexity = []

        # Go through the layers.
        for quantizer_index, layer in enumerate(self.layers):
            if (
                should_quantize_dropout
                and quantizer_index > start_drop_quantize_index
            ):
                all_indices.append(null_indices)
                continue

            quantized, embed_indices, loss, perplexity = layer(
                residual, return_idx=True, temperature=sample_codebook_temp
            )

            # Out-of-place on purpose: ``residual`` starts out as an alias of
            # the caller's ``x``, so an in-place ``-=`` would overwrite it.
            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            all_indices.append(embed_indices)
            all_losses.append(loss)
            all_perplexity.append(perplexity)

        # Stack all indices; average losses and perplexities.
        all_indices = torch.stack(all_indices, dim=-1)
        all_losses = sum(all_losses) / len(all_losses)
        all_perplexity = sum(all_perplexity) / len(all_perplexity)

        ret = (quantized_out, all_indices, all_losses, all_perplexity)

        if return_all_codes:
            # Whether to return all codes from all codebooks across layers;
            # shape (quantizer, batch, sequence length, codebook dimension).
            all_codes = self.get_codes_from_indices(all_indices)
            ret = (*ret, all_codes)

        return ret

    def quantize(self, x, return_latent=False):
        """Run every layer (no dropout) and return the code indices.

        Args:
            x: tensor to quantize; it is not modified.
            return_latent: when True, also return the per-layer quantized
                outputs.

        Returns:
            ``code_idx`` (per-layer indices stacked on the last axis), or
            ``(code_idx, all_codes)`` when ``return_latent`` is set, where
            ``all_codes`` stacks each layer's quantized output on axis 0.
        """
        all_indices = []
        all_codes = []
        residual = x

        for layer in self.layers:
            quantized, embed_indices, _loss, _perplexity = layer(
                residual, return_idx=True
            )

            residual = residual - quantized.detach()

            all_indices.append(embed_indices)
            all_codes.append(quantized)

        code_idx = torch.stack(all_indices, dim=-1)
        all_codes = torch.stack(all_codes, dim=0)

        if return_latent:
            return code_idx, all_codes
        return code_idx