|
import math |
|
import torch |
|
from torch import nn |
|
from .factorized_vector_quantize import FactorizedVectorQuantize |
|
|
|
class ResidualVQ(nn.Module):
    """Residual vector quantizer built from a stack of ``FactorizedVectorQuantize`` stages.

    Each stage quantizes whatever the previous stages left unexplained (the
    residual), and the reconstruction is the sum of all stage outputs.

    Args:
        num_quantizers: Number of stages to build when ``codebook_size`` is a
            single integer. When ``codebook_size`` is a sequence, one stage is
            built per entry and this value is only stored on the module.
        codebook_size: Either one integer used for every stage, or an iterable
            of per-stage codebook sizes.
        **kwargs: Forwarded unchanged to every ``FactorizedVectorQuantize``.
    """

    def __init__(self, *, num_quantizers, codebook_size, **kwargs):
        import numbers

        super().__init__()
        # ``numbers.Integral`` also accepts int-like scalars (e.g. numpy
        # integers, IntEnum members) that a ``type(x) == int`` check rejected;
        # those previously fell through to the loop below and raised TypeError
        # because a scalar is not iterable.
        if isinstance(codebook_size, numbers.Integral):
            codebook_size = [int(codebook_size)] * num_quantizers
        self.layers = nn.ModuleList(
            [FactorizedVectorQuantize(codebook_size=size, **kwargs) for size in codebook_size]
        )
        # Stored as given; when a per-stage list is supplied it is NOT
        # cross-checked against len(self.layers).
        self.num_quantizers = num_quantizers

    def forward(self, x: torch.Tensor):
        """Quantize ``x`` stage by stage, each stage coding the remaining residual.

        Args:
            x: Input handed to the first stage. Its expected shape is whatever
                ``FactorizedVectorQuantize`` accepts (not visible here).

        Returns:
            A tuple ``(quantized_out, all_indices, all_losses)``:
            - ``quantized_out``: sum of every stage's quantized output.
            - ``all_indices``: per-stage indices stacked along a NEW leading
              dim 0, i.e. shape ``(num_stages, *indices.shape)``.
            - ``all_losses``: per-stage losses, each reduced with ``.mean()``,
              stacked into shape ``(num_stages,)``.

        Raises:
            RuntimeError: From ``torch.stack`` when the module has no stages.
        """
        quantized_out = 0.0
        residual = x
        all_losses = []
        all_indices = []
        for layer in self.layers:
            quantized, indices, loss = layer(residual)
            # The next stage only sees what this stage failed to reconstruct.
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss.mean())
        return quantized_out, torch.stack(all_indices), torch.stack(all_losses)

    def vq2emb(self, vq: torch.Tensor, proj: bool = True):
        """Decode per-stage codes back into a summed embedding.

        Stage ``i`` receives ``vq[:, :, i]`` and decodes it with its own
        ``vq2emb(..., proj=proj)``; the stage results are summed.

        Args:
            vq: Codes with the stage axis at dim 2, presumably
                ``(B, T, num_quantizers)`` (TODO confirm against callers).
                Note that ``forward`` stacks indices along dim 0 instead, so
                its ``all_indices`` output presumably needs its stage axis
                moved to dim 2 before being passed here.
            proj: Forwarded unchanged to every stage's ``vq2emb``.

        Returns:
            Sum of the per-stage decoded embeddings.
        """
        quantized_out = 0.0
        for stage_idx, layer in enumerate(self.layers):
            quantized_out = quantized_out + layer.vq2emb(vq[:, :, stage_idx], proj=proj)
        return quantized_out

    def get_emb(self) -> list:
        """Return a list holding each stage's ``get_emb()`` result, in stage order.

        What each entry contains is defined by ``FactorizedVectorQuantize.get_emb``
        (not visible here).
        """
        return [layer.get_emb() for layer in self.layers]
|
|