|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
|
|
from repcodec.layers.vq_module import ResidualVQ |
|
|
|
|
|
class Quantizer(nn.Module):
    """Wrap a ``ResidualVQ`` codebook and handle the axis swap around it.

    Every method that accepts ``z`` swaps axes 1 and 2 before handing it to
    the codebook.
    NOTE(review): presumably ``z`` is laid out as (batch, code_dim, frames)
    and the codebook expects the feature axis last -- confirm against
    ``ResidualVQ``.
    """

    def __init__(
        self,
        code_dim: int,
        codebook_num: int,
        codebook_size: int,
    ):
        """Create the underlying residual vector quantizer.

        Args:
            code_dim: Forwarded to ``ResidualVQ`` as ``dim``.
            codebook_num: Forwarded to ``ResidualVQ`` as ``num_quantizers``.
            codebook_size: Forwarded to ``ResidualVQ`` as ``codebook_size``.
        """
        super().__init__()
        vq_kwargs = dict(
            dim=code_dim,
            num_quantizers=codebook_num,
            codebook_size=codebook_size,
        )
        self.codebook = ResidualVQ(**vq_kwargs)

    @staticmethod
    def _swap_axes_1_2(tensor):
        """Return ``tensor`` with axes 1 and 2 exchanged."""
        return tensor.transpose(2, 1)

    def initial(self):
        """Delegate to the codebook's own ``initial()`` routine."""
        self.codebook.initial()

    def forward(self, z):
        """Run ``z`` through the codebook.

        Returns:
            A ``(zq, vqloss, perplexity)`` tuple. ``zq`` is the codebook's
            first output with axes 1 and 2 swapped back; the other two values
            are passed through exactly as the codebook returned them.
        """
        quantized, loss, ppl = self.codebook(self._swap_axes_1_2(z))
        return self._swap_axes_1_2(quantized), loss, ppl

    def inference(self, z):
        """Run the codebook's ``forward_index`` on ``z``.

        Returns:
            A ``(zq, indices)`` tuple. ``zq`` has axes 1 and 2 swapped back;
            ``indices`` is returned as produced by the codebook.
        """
        quantized, codes = self.codebook.forward_index(self._swap_axes_1_2(z))
        return self._swap_axes_1_2(quantized), codes

    def encode(self, z):
        """Run the codebook's ``forward_index`` with ``flatten_idx=True``.

        Returns:
            A ``(zq, indices)`` tuple, both exactly as the codebook produced
            them.

        NOTE(review): unlike ``inference``, ``zq`` is NOT swapped back here,
        so it stays in the layout the codebook emits -- confirm this is
        intended.
        """
        quantized, codes = self.codebook.forward_index(
            self._swap_axes_1_2(z), flatten_idx=True
        )
        return quantized, codes

    def decode(self, indices):
        """Return whatever the codebook's ``lookup`` yields for ``indices``."""
        return self.codebook.lookup(indices)
|
|