from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import distributed as tdist, nn as nn
from torch.nn import functional as F
from math import sqrt
import math

from einops import rearrange, reduce, pack, unpack

import dist


def mult_along_first_dims(x, y):
    """
    returns x * y elementwise along the leading dimensions of y
    """
    ndim_to_expand = x.ndim - y.ndim
    for _ in range(ndim_to_expand):
        y = y.unsqueeze(-1)
    return x * y


def masked_mean(x, m):
    """
    takes the mean of the elements of x that are not masked
    the mean is taken along the shared leading dims of m
    equivalent to: x[m].mean(tuple(range(m.ndim)))

    The benefit of using masked_mean rather than using tensor indexing
    is that masked_mean is much faster for torch-compile on batches.

    The drawback is larger floating point errors
    """
    x = mult_along_first_dims(x, m)
    x = x / m.sum()
    return x.sum(tuple(range(m.ndim)))


def entropy_loss(
    logits,
    mask=None,
    temperature=0.01,
    sample_minimization_weight=1.0,
    batch_maximization_weight=1.0,
    eps=1e-5,
):
    """
    Entropy loss of unnormalized logits

    logits: Affinities are over the last dimension

    https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
    LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024)
    """
    probs = F.softmax(logits / temperature, -1)
    log_probs = F.log_softmax(logits / temperature + eps, -1)

    if mask is not None:
        # avg_probs = probs[mask].mean(tuple(range(probs.ndim - 1)))
        # avg_probs = einx.mean("... D -> D", probs[mask])
        avg_probs = reduce(masked_mean(probs, mask), "... D -> D", "mean")
        # avg_probs = einx.mean("... D -> D", avg_probs)
    else:
        avg_probs = reduce(probs, "... D -> D", "mean")

    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))

    sample_entropy = -torch.sum(probs * log_probs, -1)
    if mask is not None:
        # sample_entropy = sample_entropy[mask].mean()
        sample_entropy = masked_mean(sample_entropy, mask).mean()
    else:
        sample_entropy = torch.mean(sample_entropy)

    loss = (sample_minimization_weight * sample_entropy) - (
        batch_maximization_weight * avg_entropy
    )

    return sample_entropy, avg_entropy, loss
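
# --- Illustrative sketch (not part of the original module) -------------------
# A minimal, hypothetical example of driving `entropy_loss`: logits are
# affinities over the codebook (last dim) and the optional boolean mask selects
# which samples contribute. The shapes below are assumptions chosen only for
# illustration.
def _demo_entropy_loss():
    logits = torch.randn(2, 8, 1, 16)   # (B, L, groups, codebook_size)
    mask = torch.tensor([True, False])  # keep only the first sample
    sample_H, codebook_H, loss = entropy_loss(
        logits,
        mask=mask,
        sample_minimization_weight=1.0,
        batch_maximization_weight=1.0,
    )
    # Minimizing `loss` sharpens per-sample assignments (low sample entropy)
    # while spreading usage across the codebook (high batch entropy).
    return sample_H, codebook_H, loss
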
class LFQ(nn.Module):
    # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
    def __init__(
        self,
        codebook_size,
        Cvae,
        using_znorm=False,
        beta: float = 0.25,
        default_qresi_counts=0,
        v_patch_nums=None,
        quant_resi=0.5,
        share_quant_resi=4,
        num_latent_tokens=256,
        codebook_drop=0.0,
        scale=1,
        sample_minimization_weight=1.0,
        batch_maximization_weight=1.0,
        entropy_weight=0.1,
        soft_entropy=True,
        # share_quant_resi: args.qsr
    ):
        super().__init__()
        self.Cvae: int = Cvae
        self.vocab_size: int = 2**self.Cvae
        assert self.vocab_size == codebook_size
        self.using_znorm: bool = using_znorm
        self.v_patch_nums: Tuple[int] = v_patch_nums
        self.num_latent_tokens = num_latent_tokens
        self.entropy_weight = entropy_weight
        self.soft_entropy = soft_entropy
        self.persample_entropy_compute = "analytical"

        self.quant_resi_ratio = quant_resi
        if share_quant_resi == 0:  # non-shared: \phi_{1 to K} for K scales
            self.quant_resi = PhiNonShared(
                [
                    (Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
                    for _ in range(default_qresi_counts or len(self.v_patch_nums))
                ]
            )
        elif share_quant_resi == 1:  # fully shared: only a single \phi for K scales
            self.quant_resi = PhiShared(
                Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()
            )
        else:  # partially shared: \phi_{1 to share_quant_resi} for K scales
            self.quant_resi = PhiPartiallyShared(
                nn.ModuleList(
                    [
                        (
                            Phi(Cvae, quant_resi)
                            if abs(quant_resi) > 1e-6
                            else nn.Identity()
                        )
                        for _ in range(share_quant_resi)
                    ]
                )
            )

        self.register_buffer(
            "ema_vocab_hit_SV",
            torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0),
        )
        self.record_hit = 0

        self.register_buffer("mask", 2 ** torch.arange(self.Cvae), persistent=False)
        self.beta: float = beta
        self.codebook_drop = codebook_drop

        # keep the per-scale code magnitudes as a float buffer (an integer
        # buffer would break interpolation and the conv in quant_resi)
        scaler = scale ** torch.arange(len(self.v_patch_nums), dtype=torch.float32)
        if using_znorm:
            scaler = scaler / sqrt(self.Cvae)
        self.register_buffer("scaler", scaler)
        print("scale is", scaler)

        # for entropy loss
        self.sample_minimization_weight = sample_minimization_weight
        self.batch_maximization_weight = batch_maximization_weight

        # codes
        all_codes = torch.arange(codebook_size)
        bits = self.indices_to_bits(all_codes)
        codebook = bits * 2.0 - 1.0
        self.register_buffer("codebook", codebook, persistent=False)

        # only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
        self.prog_si = -1  # progressive training: not supported yet, prog_si always -1

    def extra_repr(self) -> str:
        return f"{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta}  |  S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}"

    # ===================== `forward` is only used in VAE training =====================
    def forward(
        self, f_BChw: torch.Tensor, ret_usages=False, dropout=None
    ) -> Tuple[
        torch.Tensor, Optional[List[float]], torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        dtype = f_BChw.dtype
        if dtype != torch.float32:
            f_BChw = f_BChw.float()
        B, C, H, W = f_BChw.shape
        if self.using_znorm:
            f_BChw = F.normalize(f_BChw, dim=1)
        f_no_grad = f_BChw.detach()
        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)
        # x = f_BChw

        with torch.cuda.amp.autocast(enabled=False):
            mean_vq_loss: torch.Tensor = 0.0
            mean_commit_loss: torch.Tensor = 0.0
            mean_entropy_loss: torch.Tensor = 0.0
            vocab_hit_V = torch.zeros(
                self.vocab_size, dtype=torch.float, device=f_BChw.device
            )
            SN = len(self.v_patch_nums)

            if self.training:
                max_n = len(self.v_patch_nums) + 1
                n_quantizers = torch.ones((B,)) * max_n
                n_dropout = int(B * self.codebook_drop)
                n_quantizers[:n_dropout] = dropout[:n_dropout]
                n_quantizers = n_quantizers.to(f_BChw.device)
            else:
                # keep every scale at eval time
                n_quantizers = (
                    torch.ones((B,)) * (len(self.v_patch_nums) + 1)
                ).to(f_BChw.device)

            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                codebook_value = (
                    self.scaler[si].to(device=f_BChw.device, dtype=torch.float).detach()
                )
                # find the nearest embedding
                rest_NC = (
                    F.interpolate(f_rest, size=(pn, pn), mode="area")
                    .permute(0, 2, 3, 1)
                    .reshape(-1, C)
                    if (si != SN - 1) or pn != int(sqrt(self.num_latent_tokens))
                    else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
                )
                # rest_NC = f_rest.permute(0, 2, 3, 1).reshape(-1, C)
                d_no_grad = torch.where(rest_NC > 0, codebook_value, -codebook_value)
                idx_N = self.bits_to_indices(d_no_grad > 0)

                hit_V = idx_N.bincount(minlength=self.vocab_size).float()
                if self.training and tdist.is_initialized():
                    handler = tdist.all_reduce(hit_V, async_op=True)
                else:
                    handler = None  # single-process runs skip the all-reduce

                # calc loss
                idx_Bhw = idx_N.view(B, pn, pn)

                h_BChw = (
                    F.interpolate(
                        self.indices_to_bits(idx_Bhw, si).permute(0, 3, 1, 2),
                        size=(H, W),
                        mode="bicubic",
                    ).contiguous()
                    if (si != SN - 1)
                    else self.indices_to_bits(idx_Bhw, si)
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )
                # h_BChw = self.indices_to_bits(idx_Bhw, si).permute(0, 3, 1, 2).contiguous()
                h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)

                # x = f_rest.clone().permute(0, 2, 3, 1)
                x = rearrange(f_BChw - f_hat.detach(), "b d h w -> b (h w) 1 d")

                mask = (
                    torch.full((B,), fill_value=si, device=h_BChw.device)
                    < n_quantizers
                )[:, None, None, None].int()

                f_hat = f_hat + h_BChw * mask
                f_rest -= h_BChw

                if self.training:
                    if handler is not None:
                        handler.wait()
                    if self.record_hit == 0:
                        self.ema_vocab_hit_SV[si].copy_(hit_V)
                    elif self.record_hit < 100:
                        self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
                    else:
                        self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
                    self.record_hit += 1
                vocab_hit_V.add_(hit_V)

                ratio = mask.sum() / B
                codebook = self.codebook * codebook_value

                if self.soft_entropy:
                    per_sample_entropy, codebook_entropy, avg_prob = (
                        self.soft_entropy_loss(x, si, codebook, mask.squeeze())
                    )
                    entropy_aux_loss = (
                        self.sample_minimization_weight * per_sample_entropy
                    ) - (self.batch_maximization_weight * codebook_entropy)
                else:
                    logits = 2 * torch.einsum("... i d, j d -> ... i j", x, codebook)
                    # the same as euclidean distance up to a constant
                    per_sample_entropy, codebook_entropy, entropy_aux_loss = (
                        entropy_loss(
                            logits=logits,
                            mask=mask.squeeze(),
                            sample_minimization_weight=self.sample_minimization_weight,
                            batch_maximization_weight=self.batch_maximization_weight,
                        )
                    )

                # F.mse_loss(f_hat, f_no_grad, reduction="none").mul_(mask).mean() / ratio
                mean_vq_loss += (
                    F.mse_loss(f_hat, f_no_grad, reduction="none").mul_(mask).mean()
                    / ratio
                )
                mean_commit_loss += (
                    F.mse_loss(f_hat.data, f_BChw, reduction="none")
                    .mul_(mask)
                    .mul_(self.beta / ratio)
                    .mean()
                )
                entropy_weight = self.entropy_weight / ratio
                mean_entropy_loss += entropy_aux_loss.mul_(entropy_weight)

                # x -= h_BChw.detach()

            mean_vq_loss *= 1.0 / SN
            mean_commit_loss *= 1.0 / SN
            mean_entropy_loss *= 1.0 / SN
            f_hat = (f_hat.data - f_no_grad).add_(f_BChw)

        margin = (
            (tdist.get_world_size() if tdist.is_initialized() else 1)
            * (f_BChw.numel() / f_BChw.shape[1])
            / self.vocab_size
            * 0.08
        )
        # margin = pn*pn / 100
        if ret_usages:
            usages = [
                (self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100
                for si, pn in enumerate(self.v_patch_nums)
            ]
        else:
            usages = None
        return f_hat, usages, mean_vq_loss, mean_commit_loss, mean_entropy_loss

    # ===================== `forward` is only used in VAE training =====================
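    # --- Illustrative sketch (not part of the original class) ----------------
    # `forward` ends with a straight-through estimator,
    #     f_hat = (f_hat.data - f_no_grad).add_(f_BChw),
    # so the returned tensor carries the quantized values while its gradient is
    # the identity w.r.t. the encoder features. A standalone check of that
    # identity (all names below are hypothetical):
    @staticmethod
    def _demo_straight_through():
        f = torch.randn(2, 3, requires_grad=True)           # encoder features
        q = torch.sign(f.detach())                          # quantized stand-in
        out = (q - f.detach()).add_(f)                      # straight-through trick
        out.sum().backward()
        assert torch.allclose(out, q)                       # forward: quantized values
        assert torch.allclose(f.grad, torch.ones_like(f))   # backward: identity gradient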
    def bits_to_indices(self, bits):
        """
        bits: bool tensor of little endian bits, where the last dimension is the bit dimension

        returns indices, which are long integers from 0 to self.codebook_size
        """
        assert bits.shape[-1] == self.Cvae
        indices = 2 ** torch.arange(
            0,
            self.Cvae,
            1,
            dtype=torch.long,
            device=bits.device,
        )
        return (bits * indices).sum(-1)

    def indices_to_bits(self, x, si=None):
        """
        x: long tensor of indices

        returns little endian bits
        """
        mask = 2 ** torch.arange(self.Cvae, device=x.device, dtype=torch.long)
        # x is now a bit tensor, the last dimension being the bits (least significant first)
        x = (x.unsqueeze(-1) & mask) != 0
        if si is None:
            return x
        return torch.where(x, self.scaler[si], -self.scaler[si])
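    # --- Illustrative sketch (not part of the original class) ----------------
    # A hypothetical helper showing the bit-packing round trip used above: each
    # code is Cvae signed bits, packed little-endian (bit 0 has weight 1) into
    # an integer index in [0, 2**Cvae).
    def _demo_bit_round_trip(self):
        idx = torch.arange(self.vocab_size, device=self.mask.device)
        bits = self.indices_to_bits(idx)                     # bool, shape (V, Cvae)
        assert torch.equal(self.bits_to_indices(bits), idx)  # lossless round trip
        return bits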
    def soft_entropy_loss(self, z, si, codebook, mask=None):
        if mask is not None:
            # boolean mask keeps samples; an integer mask would gather rows by index
            z = z[mask.bool()]
        distance = -2 * torch.einsum("... g c, d c -> ... g d", z, codebook)
        prob = (-distance).softmax(dim=-1)
        if self.persample_entropy_compute == "analytical":
            p = torch.sigmoid(-4 * z * self.scaler[si])
            prob = torch.stack([p, 1 - p], dim=-1)
            per_sample_entropy = (
                self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
            )
        else:
            per_sample_entropy = (
                self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
            )

        # macro average of the probability of each subgroup
        avg_prob = reduce(prob, "... g d -> g d", "mean")
        codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)

        # the approximation of the entropy is the sum of the entropy of each subgroup
        return per_sample_entropy, codebook_entropy.sum(), avg_prob

    def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
        if normalize:
            probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
        else:
            probs = count
        H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
        return H

    def embed_to_fhat(
        self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False
    ) -> Union[List[torch.Tensor], torch.Tensor]:
        ls_f_hat_BChw = []
        B = ms_h_BChw[0].shape[0]
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)
        if all_to_max_scale:
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                h_BChw = ms_h_BChw[si]
                if si < len(self.v_patch_nums) - 1:
                    h_BChw = F.interpolate(h_BChw, size=(H, W), mode="bicubic")
                h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
                f_hat.add_(h_BChw)
                if last_one:
                    ls_f_hat_BChw = f_hat
                else:
                    ls_f_hat_BChw.append(f_hat.clone())
        else:
            # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
            # WARNING: this should only be used for experimental purpose
            f_hat = ms_h_BChw[0].new_zeros(
                B,
                self.Cvae,
                self.v_patch_nums[0],
                self.v_patch_nums[0],
                dtype=torch.float32,
            )
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                f_hat = F.interpolate(f_hat, size=(pn, pn), mode="bicubic")
                h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si])
                f_hat.add_(h_BChw)
                if last_one:
                    ls_f_hat_BChw = f_hat
                else:
                    ls_f_hat_BChw.append(f_hat)

        return ls_f_hat_BChw

    def f_to_idxBl_or_fhat(
        self,
        f_BChw: torch.Tensor,
        to_fhat: bool,
        v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
    ) -> List[
        Union[torch.Tensor, torch.LongTensor]
    ]:  # z_BChw is the feature from inp_img_no_grad
        B, C, H, W = f_BChw.shape
        if self.using_znorm:
            f_BChw = F.normalize(f_BChw, dim=1)
        f_no_grad = f_BChw.detach()
        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)

        f_hat_or_idx_Bl: List[torch.Tensor] = []

        patch_hws = [
            (pn, pn) if isinstance(pn, int) else (pn[0], pn[1])
            for pn in (v_patch_nums or self.v_patch_nums)
        ]  # from small to large
        # assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'

        SN = len(patch_hws)
        for si, (ph, pw) in enumerate(patch_hws):  # from small to large
            codebook_value = (
                self.scaler[si].to(device=f_BChw.device, dtype=torch.float).detach()
            )
            if 0 <= self.prog_si < si:
                break  # progressive training: not supported yet, prog_si always -1
            # find the nearest embedding
            z_NC = (
                F.interpolate(f_rest, size=(ph, pw), mode="area")
                .permute(0, 2, 3, 1)
                .reshape(-1, C)
                # 16 == int(sqrt(num_latent_tokens)) under the default 256 latent tokens
                if (si != SN - 1) or ph != 16
                else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
            )
            d_no_grad = torch.where(z_NC > 0, codebook_value, -codebook_value)
            idx_N = self.bits_to_indices(d_no_grad > 0)
            idx_Bhw = idx_N.view(B, ph, pw)
            h_BChw = (
                F.interpolate(
                    self.indices_to_bits(idx_Bhw, si).permute(0, 3, 1, 2),
                    size=(H, W),
                    mode="bicubic",
                ).contiguous()
                if (si != SN - 1)
                else self.indices_to_bits(idx_Bhw, si).permute(0, 3, 1, 2).contiguous()
            )
            h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
            f_hat.add_(h_BChw)
            f_rest.sub_(h_BChw)
            f_hat_or_idx_Bl.append(
                f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw)
            )

        return f_hat_or_idx_Bl
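    # --- Illustrative sketch (not part of the original class) ----------------
    # A toy-sized, hypothetical example of encoding a feature map into
    # multi-scale token indices with `f_to_idxBl_or_fhat`. The hyperparameters
    # are assumptions chosen only so that the shapes line up.
    @staticmethod
    def _demo_encode_to_indices():
        quantizer = LFQ(16, 4, v_patch_nums=[1, 2, 4], num_latent_tokens=16)
        f = torch.randn(1, 4, 4, 4)  # (B, Cvae, H, W)
        idx_Bl = quantizer.f_to_idxBl_or_fhat(f, to_fhat=False)
        # one long tensor of indices per scale: shapes (1, 1), (1, 4), (1, 16)
        return [idx.shape for idx in idx_Bl]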
    # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
    def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
        next_scales = []
        B = gt_ms_idx_Bl[0].shape[0]
        C = self.Cvae
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)

        f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
        pn_next: int = self.v_patch_nums[0]
        for si in range(SN - 1):
            if self.prog_si == 0 or (0 <= self.prog_si - 1 < si):
                break  # progressive training: not supported yet, prog_si always -1
            h_BChw = F.interpolate(
                # LFQ has no nn.Embedding; decode indices back to codes via indices_to_bits
                self.indices_to_bits(gt_ms_idx_Bl[si], si)
                .transpose_(1, 2)
                .view(B, C, pn_next, pn_next),
                size=(H, W),
                mode="bicubic",
            )
            f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw))
            pn_next = self.v_patch_nums[si + 1]
            next_scales.append(
                F.interpolate(f_hat, size=(pn_next, pn_next), mode="area")
                .view(B, C, -1)
                .transpose(1, 2)
            )
        return (
            torch.cat(next_scales, dim=1) if len(next_scales) else None
        )  # cat BlCs to BLC, this should be float32

    # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
    def get_next_autoregressive_input(
        self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:  # only used in VAR inference
        HW = self.v_patch_nums[-1]
        if si != SN - 1:
            h = self.quant_resi[si / (SN - 1)](
                F.interpolate(h_BChw, size=(HW, HW), mode="bicubic")
            )  # conv after upsample
            f_hat.add_(h)
            return f_hat, F.interpolate(
                f_hat,
                size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]),
                mode="area",
            )
        else:
            h = self.quant_resi[si / (SN - 1)](h_BChw)
            f_hat.add_(h)
            return f_hat, f_hat


class Phi(nn.Conv2d):
    def __init__(self, embed_dim, quant_resi):
        ks = 3
        super().__init__(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=ks,
            stride=1,
            padding=ks // 2,
        )
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(
            self.resi_ratio
        )


class PhiShared(nn.Module):
    def __init__(self, qresi: Phi):
        super().__init__()
        self.qresi: Phi = qresi

    def __getitem__(self, _) -> Phi:
        return self.qresi


class PhiPartiallyShared(nn.Module):
    def __init__(self, qresi_ls: nn.ModuleList):
        super().__init__()
        self.qresi_ls = qresi_ls
        K = len(qresi_ls)
        self.ticks = (
            np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)
            if K == 4
            else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
        )

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]

    def extra_repr(self) -> str:
        return f"ticks={self.ticks}"


class PhiNonShared(nn.ModuleList):
    def __init__(self, qresi: List):
        super().__init__(qresi)
        # self.qresi = qresi
        K = len(qresi)
        self.ticks = (
            np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)
            if K == 4
            else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
        )

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        return super().__getitem__(
            np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()
        )

    def extra_repr(self) -> str:
        return f"ticks={self.ticks}"
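
# --- Illustrative sketch (not part of the original module) -------------------
# The Phi wrappers are indexed by a relative scale position in [0, 1]
# (si / (SN - 1) in the quantizer) that is snapped to the nearest tick.
# A minimal, hypothetical check of that routing:
def _demo_phi_routing():
    shared = PhiPartiallyShared(nn.ModuleList([Phi(4, 0.5) for _ in range(4)]))
    # With K=4 the ticks are linspace(1/12, 11/12, 4), so positions 0.0 and 1.0
    # snap to the first and last Phi respectively.
    assert shared[0.0] is shared.qresi_ls[0]
    assert shared[1.0] is shared.qresi_ls[-1]
    x = torch.randn(1, 4, 8, 8)
    y = shared[0.5](x)  # residual blend: 0.5 * x + 0.5 * conv(x)
    assert y.shape == x.shape
    return y
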
def schedule(ratio, total_unknown, method="cosine"):
    """Generates a mask rate by scheduling mask functions R.

    Given a ratio in [0, 1), we generate a masking ratio from (0, 1]. During
    training, the input ratio is uniformly sampled; during inference, the input
    ratio is based on the step number divided by the total iteration number:
    t/T. Based on experiments, we find that masking more in training helps.

    Args:
        ratio: The uniformly sampled ratio [0, 1) as input.
        total_unknown: The total number of tokens that can be masked out. For
            example, in MaskGIT, total_unknown = 256 for 256x256 images and
            1024 for 512x512 images.
        method: implemented functions are ["uniform", "cosine", "pow", "log", "exp"]
            "pow2.5" represents x^2.5

    Returns:
        The mask rate (float).
    """
    if method == "uniform":
        mask_ratio = 1.0 - ratio
    elif "pow" in method:
        exponent = float(method.replace("pow", ""))
        mask_ratio = 1.0 - ratio**exponent
    elif method == "cosine":
        mask_ratio = np.cos(math.pi / 2.0 * ratio)
    elif method == "log":
        mask_ratio = -np.log2(ratio) / np.log2(total_unknown)
    elif method == "exp":
        mask_ratio = 1 - np.exp2(-np.log2(total_unknown) * (1 - ratio))
    # Clamps the mask rate into [0, 1].
    mask_ratio = np.clip(mask_ratio, 0, 1.0)
    return mask_ratio


if __name__ == "__main__":
    batch_size = 4
    seq_len = 16
    num_classes = 4096

    # # Generate random logits and boolean mask
    # logits = torch.randn(batch_size, seq_len, seq_len, num_classes)
    mask = torch.ones(batch_size, dtype=torch.bool)

    # # Calculate entropy loss
    # sample_entropy, avg_entropy, loss = entropy_loss(
    #     logits,
    #     mask=mask,
    #     sample_minimization_weight=1.0,
    #     batch_maximization_weight=1.0,
    # )

    # # Output results
    # print("Sample Entropy for mask:", sample_entropy)
    # print("Average Entropy for mask:", avg_entropy)
    # print("Entropy Loss for mask:", loss)

    # # Calculate entropy loss
    # sample_entropy, avg_entropy, loss = entropy_loss(
    #     logits,
    #     sample_minimization_weight=1.0,
    #     batch_maximization_weight=1.0,
    # )

    # # Output results
    # print("Sample Entropy:", sample_entropy)
    # print("Average Entropy:", avg_entropy)
    # print("Entropy Loss:", loss)

    quantizer = LFQ(
        4096,
        12,
        using_znorm=False,
        v_patch_nums=[1, 2, 3, 4, 5, 6, 8, 10, 12, 16],
    )

    z = torch.randn(batch_size, seq_len * seq_len, 1, 12)
    for i in range(10):
        codebook = quantizer.codebook * quantizer.scaler[i]
        logits = 2 * torch.einsum("... i d, j d -> ... i j", z, codebook)

        per_sample_entropy, codebook_entropy, avg_prob = quantizer.soft_entropy_loss(
            z, i, codebook, mask
        )
        print("Soft Sample Entropy :", per_sample_entropy)
        print("Soft codebook Entropy:", codebook_entropy)
        print("Soft Entropy Loss", per_sample_entropy - codebook_entropy)

        sample_entropy, avg_entropy, loss = entropy_loss(
            logits,
            mask=mask,
            sample_minimization_weight=1.0,
            batch_maximization_weight=1.0,
        )
        print("Sample Entropy :", sample_entropy)
        print("codebook Entropy:", avg_entropy)
        print("Entropy Loss", loss)

    image_feats = torch.randn(
        2, 12, 16, 16
    )  # 12 is Cvae; codebook_size must equal 2**Cvae (here 4096)

    dropout_rand = torch.randint(3, len([1, 2, 3, 4, 5, 6, 8, 10, 12, 16]) + 1, (2,))

    quantized, usages, vq_loss, commit_loss, entropy_aux = quantizer(
        image_feats, ret_usages=True, dropout=dropout_rand
    )  # you may want to experiment with temperature

    assert image_feats.shape == quantized.shape
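    # --- Illustrative sketch (not part of the original script) ---------------
    # `schedule` maps a uniformly sampled ratio in [0, 1) to a mask rate;
    # e.g. the cosine schedule keeps the mask rate high for small ratios.
    for r in (0.0, 0.25, 0.5, 0.75):
        print(
            "ratio", r,
            "cosine", round(float(schedule(r, 256, "cosine")), 3),
            "pow2.5", round(float(schedule(r, 256, "pow2.5")), 3),
        )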