import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import torch.distributed as dist

import optvq.utils.logger as L


class VectorQuantizer(nn.Module):
    def __init__(self, n_e: int = 1024, e_dim: int = 128,
                 beta: float = 1.0, use_norm: bool = False,
                 use_proj: bool = True, fix_codes: bool = False,
                 loss_q_type: str = "ce",
                 num_head: int = 1,
                 start_quantize_steps: int = None):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.loss_q_type = loss_q_type
        self.num_head = num_head
        self.start_quantize_steps = start_quantize_steps
        self.code_dim = self.e_dim // self.num_head

        self.norm = lambda x: F.normalize(x, p=2.0, dim=-1, eps=1e-6) if use_norm else x
        assert not use_norm, (
            "use_norm=True is no longer supported: the norm operation lacks "
            "theoretical analysis and may cause unpredictable instability."
        )
        self.use_proj = use_proj

        self.embedding = nn.Embedding(num_embeddings=n_e, embedding_dim=self.code_dim)
        if use_proj:
            self.proj = nn.Linear(self.code_dim, self.code_dim)
            torch.nn.init.normal_(self.proj.weight, std=self.code_dim ** -0.5)
        if fix_codes:
            self.embedding.weight.requires_grad = False

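    # Note on the multi-head layout (derived from the forward pass below): each
    # e_dim-dimensional vector is split into num_head chunks of
    # code_dim = e_dim // num_head entries, and every chunk is quantized
    # independently against the same codebook. For example, e_dim=128 with
    # num_head=4 looks up four 32-dimensional codes per token.
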
    def reshape_input(self, x: Tensor):
        """
        (B, C, H, W) / (B, T, C) -> (B, T, C)
        """
        if x.ndim == 4:
            _, C, H, W = x.size()
            x = x.permute(0, 2, 3, 1).contiguous().view(-1, H * W, C)
            return x, {"size": (H, W)}
        elif x.ndim == 3:
            return x, None
        else:
            raise ValueError("Invalid input shape!")

    def recover_output(self, x: Tensor, info):
        if info is not None:
            H, W = info["size"]
            if x.ndim == 3:
                C = x.size(2)
                return x.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous()
            elif x.ndim == 2:
                return x.view(-1, H, W)
            else:
                raise ValueError("Invalid input shape!")
        else:
            return x

    def get_codebook(self, return_numpy: bool = True):
        embed = self.proj(self.embedding.weight) if self.use_proj else self.embedding.weight
        if return_numpy:
            return embed.data.cpu().numpy()
        else:
            return embed.data

    def quantize_input(self, query, reference):
        # pairwise L2 distances between queries and codebook entries
        query2ref = torch.cdist(query, reference, p=2.0)

        # nearest-neighbor assignment
        indices = torch.argmin(query2ref, dim=-1)
        nearest_ref = reference[indices]

        return indices, nearest_ref, query2ref

    def compute_codebook_loss(self, query, indices, nearest_ref, beta: float, query2ref):
        if self.loss_q_type == "l2":
            loss = torch.mean((query - nearest_ref.detach()).pow(2)) + \
                   torch.mean((nearest_ref - query.detach()).pow(2)) * beta
        elif self.loss_q_type == "l1":
            loss = torch.mean((query - nearest_ref.detach()).abs()) + \
                   torch.mean((nearest_ref - query.detach()).abs()) * beta
        elif self.loss_q_type == "ce":
            # negative distances act as logits over the codebook
            loss = F.cross_entropy(- query2ref, indices)
        else:
            raise ValueError(f"Unknown loss_q_type: {self.loss_q_type}")

        return loss

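    # The "l2" branch above is the usual VQ objective written with stop-gradients,
    #     loss = mean((q - sg[e]) ** 2) + beta * mean((e - sg[q]) ** 2),
    # where q is the encoder output, e its nearest code, and sg[.] = detach().
    # The first term trains the encoder (commitment), the second, scaled by
    # beta, moves the codes; "l1" uses absolute errors instead, and "ce" treats
    # the negative distances to all codes as logits for a cross-entropy against
    # the nearest-code index.
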
    def compute_quantized_output(self, x, x_q):
        if self.start_quantize_steps is not None:
            if self.training and L.log.total_steps < self.start_quantize_steps:
                # warm-up phase: bypass quantization entirely
                L.log.add_scalar("params/quantize_ratio", 0.0)
                return x
            else:
                L.log.add_scalar("params/quantize_ratio", 1.0)
                return x + (x_q - x).detach()
        else:
            L.log.add_scalar("params/quantize_ratio", 1.0)
            return x + (x_q - x).detach()

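    # compute_quantized_output applies the straight-through estimator: the
    # forward value of x + (x_q - x).detach() equals x_q, but its gradient with
    # respect to x is the identity, so gradients from the decoder bypass the
    # non-differentiable code lookup and flow directly into the encoder.
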
    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, x: Tensor):
        """
        Quantize the input tensor x with the embedding table.

        Args:
            x (Tensor): input tensor with shape (B, C, H, W) or (B, T, C)
        Returns:
            (tuple): containing (x_q, loss, indices)
        """
        x = x.float()
        x, info = self.reshape_input(x)
        B, T, C = x.size()
        x = x.view(-1, C)
        embed = self.proj(self.embedding.weight) if self.use_proj else self.embedding.weight

        # split each token into num_head chunks of size code_dim
        if self.num_head > 1:
            x = x.view(-1, self.code_dim)

        x, embed = self.norm(x), self.norm(embed)

        # assign each (sub-)vector to a code and compute the codebook loss
        indices, x_q, query2ref = self.quantize_input(x, embed)
        loss = self.compute_codebook_loss(
            query=x, indices=indices, nearest_ref=x_q,
            beta=self.beta, query2ref=query2ref
        )

        # log codebook statistics during training
        if self.training and L.GET_STATS:
            with torch.no_grad():
                num_unique = torch.unique(indices).size(0)
                x_norm_mean = torch.mean(x.norm(dim=-1))
                embed_norm_mean = torch.mean(embed.norm(dim=-1))
                diff_norm_mean = torch.mean((x_q - x).norm(dim=-1))
                x2e_mean = query2ref.mean()
                L.log.add_scalar("params/num_unique", num_unique)
                L.log.add_scalar("params/x_norm", x_norm_mean.item())
                L.log.add_scalar("params/embed_norm", embed_norm_mean.item())
                L.log.add_scalar("params/diff_norm", diff_norm_mean.item())
                L.log.add_scalar("params/x2e_mean", x2e_mean.item())

        # straight-through output, then restore the original layout
        x_q = self.compute_quantized_output(x, x_q).view(B, T, C)
        indices = indices.view(B, T, self.num_head)

        x_q = self.recover_output(x_q, info)
        indices = self.recover_output(indices, info)

        return x_q, loss, indices

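
# Illustrative usage (a sketch, not an official entry point; `encoder` stands
# for any backbone producing a (B, C, H, W) feature map). Note that the
# quantizer reports statistics through optvq.utils.logger, so `L.log` must be
# configured before it is called:
#
#     quantizer = VectorQuantizer(n_e=1024, e_dim=128, beta=1.0, loss_q_type="ce")
#     z = encoder(images)                  # e.g. (B, 128, 16, 16)
#     z_q, loss_q, indices = quantizer(z)  # z_q: (B, 128, 16, 16), indices: (B, num_head, 16, 16)
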
def sinkhorn(cost: Tensor, n_iters: int = 3, epsilon: float = 1.0, is_distributed: bool = False):
    """
    Sinkhorn-Knopp normalization of an assignment matrix.

    Args:
        cost (Tensor): shape (B, K); epsilon acts as an inverse temperature
            on exp(-cost * epsilon).
    """
    Q = torch.exp(- cost * epsilon).t()
    if is_distributed:
        B = Q.size(1) * dist.get_world_size()
    else:
        B = Q.size(1)
    K = Q.size(0)

    # make the whole matrix sum to 1
    sum_Q = torch.sum(Q)
    if is_distributed:
        dist.all_reduce(sum_Q)
    Q /= (sum_Q + 1e-8)

    for _ in range(n_iters):
        # normalize rows: each code receives a total mass of 1 / K
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        if is_distributed:
            dist.all_reduce(sum_of_rows)
        Q /= (sum_of_rows + 1e-8)
        Q /= K

        # normalize columns: each sample distributes a total mass of 1 / B
        Q /= (torch.sum(Q, dim=0, keepdim=True) + 1e-8)
        Q /= B

    Q *= B
    return Q.t()


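# Marginals of the returned assignment (read off the normalization above): for
# a (B, K) cost, sinkhorn() returns a (B, K) matrix whose rows each sum to ~1
# (a distribution over codes per query) and whose columns each sum to ~B / K,
# i.e. the codes are encouraged to be used at a roughly uniform rate.
# VectorQuantizerSinkhorn below either samples from or takes the argmax of
# these rows.
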
class VectorQuantizerSinkhorn(VectorQuantizer):
    def __init__(self, epsilon: float = 10.0, n_iters: int = 5,
                 normalize_mode: str = "all", use_prob: bool = True,
                 *args, **kwargs):
        super(VectorQuantizerSinkhorn, self).__init__(*args, **kwargs)
        self.epsilon = epsilon
        self.n_iters = n_iters
        self.normalize_mode = normalize_mode
        self.use_prob = use_prob

    def normalize(self, A, dim, mode="all"):
        if mode == "all":
            # standardize, then shift so that the minimum cost is zero
            A = (A - A.mean()) / (A.std() + 1e-6)
            A = A - A.min()
        elif mode == "dim":
            # scale by the square root of the feature dimension
            A = A / math.sqrt(dim)
        elif mode == "null":
            pass
        return A

    def quantize_input(self, query, reference):
        # pairwise L2 distances, kept with gradients for the codebook loss
        query2ref = torch.cdist(query, reference, p=2.0)

        # the Sinkhorn assignment itself is computed without gradients
        with torch.no_grad():
            is_distributed = dist.is_initialized() and dist.get_world_size() > 1
            normalized_cost = self.normalize(query2ref, dim=reference.size(1), mode=self.normalize_mode)
            Q = sinkhorn(normalized_cost, n_iters=self.n_iters, epsilon=self.epsilon, is_distributed=is_distributed)

        if self.use_prob:
            # sample indices from the assignment distribution; the small constant
            # keeps the most likely entry strictly positive before sampling
            max_q_id = torch.argmax(Q, dim=-1)
            Q[torch.arange(Q.size(0)), max_q_id] += 1e-8
            indices = torch.multinomial(Q, num_samples=1).squeeze(-1)
        else:
            indices = torch.argmax(Q, dim=-1)
        nearest_ref = reference[indices]

        if self.training and L.GET_STATS:
            if L.log.total_steps % 1000 == 0:
                L.log.add_histogram("params/normalized_cost", normalized_cost)

        return indices, nearest_ref, query2ref

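
# Illustrative construction (a sketch; the argument values below are the
# defaults, not a recommended configuration). The Sinkhorn variant keeps the
# VectorQuantizer interface and only changes how indices are assigned:
#
#     quantizer = VectorQuantizerSinkhorn(
#         epsilon=10.0, n_iters=5, normalize_mode="all", use_prob=True,
#         n_e=1024, e_dim=128,
#     )
#     z_q, loss_q, indices = quantizer(z)   # same signature as VectorQuantizer
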
class Identity(VectorQuantizer):
    """A pass-through quantizer: returns the input unchanged with a zero loss."""

    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, x: Tensor):
        x = x.float()
        loss_q = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        if self.training and L.GET_STATS:
            with torch.no_grad():
                x_flatten, _ = self.reshape_input(x)
                x_norm_mean = torch.mean(x_flatten.norm(dim=-1))
                L.log.add_scalar("params/x_norm", x_norm_mean.item())

        return x, loss_q, None
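

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not an official entry point):
    # checks that sinkhorn() produces a near-doubly-normalized assignment for a
    # random cost matrix in the single-process case.
    torch.manual_seed(0)
    cost = torch.rand(8, 16)  # distances between B=8 queries and K=16 codes
    Q = sinkhorn(cost, n_iters=3, epsilon=1.0, is_distributed=False)
    print("assignment shape:", tuple(Q.shape))        # (8, 16)
    print("row sums (each ~1):", Q.sum(dim=1))
    print("column sums (each ~B/K = 0.5):", Q.sum(dim=0))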