# optvq/models/quantizer.py
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------
import math
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.distributed as dist
import optvq.utils.logger as L
class VectorQuantizer(nn.Module):
    """Nearest-neighbor vector quantizer with a straight-through estimator.

    Args:
        n_e (int): number of codebook entries.
        e_dim (int): dimension of the (concatenated) code vectors.
        beta (float): weight on the codebook-update term of the "l1"/"l2"
            quantization loss (unused for the "ce" loss).
        use_norm (bool): deprecated; must be False (see the assert below).
        use_proj (bool): if True, pass the codebook through a linear projection.
        fix_codes (bool): if True, freeze the codebook weights.
        loss_q_type (str): quantization loss, one of {"l2", "l1", "ce"}.
        num_head (int): number of quantization heads; each head quantizes
            e_dim // num_head dimensions against the shared codebook.
        start_quantize_steps (Optional[int]): if given, bypass quantization
            (pass the features through) before this training step.
    """

    def __init__(self, n_e: int = 1024, e_dim: int = 128,
                 beta: float = 1.0, use_norm: bool = False,
                 use_proj: bool = True, fix_codes: bool = False,
                 loss_q_type: str = "ce",
                 num_head: int = 1,
                 start_quantize_steps: Optional[int] = None):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.loss_q_type = loss_q_type
self.num_head = num_head
self.start_quantize_steps = start_quantize_steps
self.code_dim = self.e_dim // self.num_head
        self.norm = lambda x: F.normalize(x, p=2.0, dim=-1, eps=1e-6) if use_norm else x
        assert not use_norm, (
            "use_norm=True is no longer supported: the norm operation has no "
            "theoretical analysis here and may cause unpredictable instability."
        )
self.use_proj = use_proj
self.embedding = nn.Embedding(num_embeddings=n_e, embedding_dim=self.code_dim)
if use_proj:
self.proj = nn.Linear(self.code_dim, self.code_dim)
torch.nn.init.normal_(self.proj.weight, std=self.code_dim ** -0.5)
if fix_codes:
self.embedding.weight.requires_grad = False
def reshape_input(self, x: Tensor):
"""
(B, C, H, W) / (B, T, C) -> (B, T, C)
"""
if x.ndim == 4:
_, C, H, W = x.size()
x = x.permute(0, 2, 3, 1).contiguous().view(-1, H * W, C)
return x, {"size": (H, W)}
elif x.ndim == 3:
return x, None
else:
raise ValueError("Invalid input shape!")
    def recover_output(self, x: Tensor, info):
        """
        Inverse of reshape_input: map (B, T, C) features back to (B, C, H, W)
        and (B, T) indices back to (B, H, W) when the original input was 4D.
        """
if info is not None:
H, W = info["size"]
if x.ndim == 3: # features (B, T, C) -> (B, C, H, W)
C = x.size(2)
return x.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous()
elif x.ndim == 2: # indices (B, T) -> (B, H, W)
return x.view(-1, H, W)
else:
raise ValueError("Invalid input shape!")
else: # features (B, T, C) or indices (B, T)
return x
def get_codebook(self, return_numpy: bool = True):
embed = self.proj(self.embedding.weight) if self.use_proj else self.embedding.weight
if return_numpy:
return embed.data.cpu().numpy()
else:
return embed.data
def quantize_input(self, query, reference):
# compute the distance matrix
query2ref = torch.cdist(query, reference, p=2.0) # (B1, B2)
# find the nearest embedding
indices = torch.argmin(query2ref, dim=-1) # (B1,)
nearest_ref = reference[indices] # (B1, C)
return indices, nearest_ref, query2ref
def compute_codebook_loss(self, query, indices, nearest_ref, beta: float, query2ref):
# compute the loss
if self.loss_q_type == "l2":
loss = torch.mean((query - nearest_ref.detach()).pow(2)) + \
torch.mean((nearest_ref - query.detach()).pow(2)) * beta
elif self.loss_q_type == "l1":
loss = torch.mean((query - nearest_ref.detach()).abs()) + \
torch.mean((nearest_ref - query.detach()).abs()) * beta
elif self.loss_q_type == "ce":
loss = F.cross_entropy(- query2ref, indices)
return loss
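    # Note on compute_codebook_loss above: in the "l2" case it is the standard
    # VQ-VAE objective (sg[.] = stop-gradient),
    #     L_q = mean(|| x - sg[x_q] ||^2) + beta * mean(|| x_q - sg[x] ||^2),
    # while the "ce" case treats the negated distances to all codes as logits
    # and applies a cross-entropy against the assigned code indices.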
    def compute_quantized_output(self, x, x_q):
        if (self.start_quantize_steps is not None and self.training
                and L.log.total_steps < self.start_quantize_steps):
            # warm-up: skip quantization and pass the continuous features through
            L.log.add_scalar("params/quantize_ratio", 0.0)
            return x
        # straight-through estimator: the forward pass uses x_q, gradients flow to x
        L.log.add_scalar("params/quantize_ratio", 1.0)
        return x + (x_q - x).detach()
@torch.autocast(device_type="cuda", enabled=False)
def forward(self, x: Tensor):
"""
Quantize the input tensor x with the embedding table.
Args:
x (Tensor): input tensor with shape (B, C, H, W) or (B, T, C)
Returns:
(tuple) containing: (x_q, loss, indices)
"""
x = x.float()
x, info = self.reshape_input(x)
B, T, C = x.size()
x = x.view(-1, C) # (B * T, C)
embed = self.proj(self.embedding.weight) if self.use_proj else self.embedding.weight
# split the x if multi-head is used
if self.num_head > 1:
x = x.view(-1, self.code_dim) # (B * T * nH, dC)
# compute the distance between x and each embedding
x, embed = self.norm(x), self.norm(embed)
# compute losses
indices, x_q, query2ref = self.quantize_input(x, embed)
loss = self.compute_codebook_loss(
query=x, indices=indices, nearest_ref=x_q,
beta=self.beta, query2ref=query2ref
)
# compute statistics
if self.training and L.GET_STATS:
with torch.no_grad():
num_unique = torch.unique(indices).size(0)
x_norm_mean = torch.mean(x.norm(dim=-1))
embed_norm_mean = torch.mean(embed.norm(dim=-1))
diff_norm_mean = torch.mean((x_q - x).norm(dim=-1))
x2e_mean = query2ref.mean()
L.log.add_scalar("params/num_unique", num_unique)
L.log.add_scalar("params/x_norm", x_norm_mean.item())
L.log.add_scalar("params/embed_norm", embed_norm_mean.item())
L.log.add_scalar("params/diff_norm", diff_norm_mean.item())
L.log.add_scalar("params/x2e_mean", x2e_mean.item())
# compute the final x_q
x_q = self.compute_quantized_output(x, x_q).view(B, T, C)
indices = indices.view(B, T, self.num_head)
# for output
x_q = self.recover_output(x_q, info)
indices = self.recover_output(indices, info)
return x_q, loss, indices
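# Usage sketch for VectorQuantizer (illustrative shapes; assumes the global
# logger in optvq.utils.logger has been initialized by the training entry
# point before forward() is called):
#
#     quantizer = VectorQuantizer(n_e=1024, e_dim=128, loss_q_type="ce")
#     feats = torch.randn(2, 128, 16, 16)       # (B, C, H, W) encoder output
#     x_q, loss_q, indices = quantizer(feats)   # x_q: (2, 128, 16, 16)
#                                               # indices: (2, 1, 16, 16)
#     # loss_q would typically be added to the reconstruction loss.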
def sinkhorn(cost: Tensor, n_iters: int = 3, epsilon: float = 1.0, is_distributed: bool = False):
    """
    Sinkhorn normalization: turn exp(-epsilon * cost) into a balanced soft
    assignment by alternating row and column normalization.
    Args:
        cost (Tensor): cost matrix with shape (B, K)
        n_iters (int): number of normalization iterations
        epsilon (float): sharpness of the exponential kernel
        is_distributed (bool): if True, all-reduce the normalizers across processes
    Returns:
        Tensor: assignment matrix with shape (B, K); each row sums to 1
    """
Q = torch.exp(- cost * epsilon).t() # (K, B)
if is_distributed:
B = Q.size(1) * dist.get_world_size()
else:
B = Q.size(1)
K = Q.size(0)
# make the matrix sums to 1
sum_Q = torch.sum(Q)
if is_distributed:
dist.all_reduce(sum_Q)
Q /= (sum_Q + 1e-8)
for _ in range(n_iters):
# normalize each row: total weight per prototype must be 1/K
sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
if is_distributed:
dist.all_reduce(sum_of_rows)
Q /= (sum_of_rows + 1e-8)
Q /= K
# normalize each column: total weight per sample must be 1/B
Q /= (torch.sum(Q, dim=0, keepdim=True) + 1e-8)
Q /= B
Q *= B # the columns must sum to 1 so that Q is an assignment
return Q.t() # (B, K)
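# Illustrative check of the balancing property of sinkhorn(): each row of the
# returned (B, K) matrix is (approximately) a distribution over the K codes,
# and total code usage is pushed toward uniform:
#
#     cost = torch.cdist(torch.randn(8, 4), torch.randn(16, 4))  # (B=8, K=16)
#     Q = sinkhorn(cost, n_iters=3, epsilon=1.0)
#     assert torch.allclose(Q.sum(dim=1), torch.ones(8), atol=1e-4)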
class VectorQuantizerSinkhorn(VectorQuantizer):
    """Vector quantizer that assigns codes via the Sinkhorn algorithm applied
    to the (normalized) distance matrix, instead of a plain nearest-neighbor
    argmin.

    Args:
        epsilon (float): sharpness of the Sinkhorn kernel.
        n_iters (int): number of Sinkhorn iterations.
        normalize_mode (str): cost normalization before Sinkhorn, one of
            {"all", "dim", "null"}.
        use_prob (bool): if True, sample indices from the assignment matrix;
            otherwise take the argmax.
        *args, **kwargs: forwarded to VectorQuantizer.
    """

    def __init__(self, epsilon: float = 10.0, n_iters: int = 5,
                 normalize_mode: str = "all", use_prob: bool = True,
                 *args, **kwargs):
super(VectorQuantizerSinkhorn, self).__init__(*args, **kwargs)
self.epsilon = epsilon
self.n_iters = n_iters
self.normalize_mode = normalize_mode
self.use_prob = use_prob
def normalize(self, A, dim, mode="all"):
if mode == "all":
A = (A - A.mean()) / (A.std() + 1e-6)
A = A - A.min()
elif mode == "dim":
A = A / math.sqrt(dim)
        elif mode == "null":
            pass
        else:
            raise ValueError(f"Unknown normalize mode: {mode}")
        return A
def quantize_input(self, query, reference):
# compute the distance matrix
query2ref = torch.cdist(query, reference, p=2.0) # (B1, B2)
# compute the assignment matrix
with torch.no_grad():
is_distributed = dist.is_initialized() and dist.get_world_size() > 1
normalized_cost = self.normalize(query2ref, dim=reference.size(1), mode=self.normalize_mode)
Q = sinkhorn(normalized_cost, n_iters=self.n_iters, epsilon=self.epsilon, is_distributed=is_distributed)
if self.use_prob:
# avoid the zero value problem
max_q_id = torch.argmax(Q, dim=-1)
Q[torch.arange(Q.size(0)), max_q_id] += 1e-8
                indices = torch.multinomial(Q, num_samples=1).squeeze(-1)
else:
indices = torch.argmax(Q, dim=-1)
nearest_ref = reference[indices]
if self.training and L.GET_STATS:
if L.log.total_steps % 1000 == 0:
L.log.add_histogram("params/normalized_cost", normalized_cost)
return indices, nearest_ref, query2ref
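# Usage sketch for the Sinkhorn variant (illustrative values); base-class
# keyword arguments are forwarded through *args/**kwargs:
#
#     quantizer = VectorQuantizerSinkhorn(
#         epsilon=10.0, n_iters=5, normalize_mode="all", use_prob=True,
#         n_e=1024, e_dim=128, loss_q_type="ce",
#     )
#     x_q, loss_q, indices = quantizer(torch.randn(2, 128, 16, 16))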
class Identity(VectorQuantizer):
    """Pass-through variant: returns the input unchanged with a zero
    quantization loss and no indices, keeping the VectorQuantizer interface."""

@torch.autocast(device_type="cuda", enabled=False)
def forward(self, x: Tensor):
x = x.float()
loss_q = torch.tensor(0.0, device=x.device, dtype=x.dtype)
# compute statistics
if self.training and L.GET_STATS:
with torch.no_grad():
x_flatten, _ = self.reshape_input(x)
x_norm_mean = torch.mean(x_flatten.norm(dim=-1))
L.log.add_scalar("params/x_norm", x_norm_mean.item())
return x, loss_q, None
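# Usage sketch for Identity (illustrative): it can stand in for a quantizer
# when no quantization is wanted, e.g. as a continuous-autoencoder baseline:
#
#     quantizer = Identity(n_e=1024, e_dim=128)
#     x_q, loss_q, indices = quantizer(torch.randn(2, 128, 16, 16))
#     # x_q equals the input, loss_q == 0.0, indices is None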