import json

import torch
import torch.nn as nn

from Preprocessing.Codec.env import AttrDict
from Preprocessing.Codec.models import Encoder, Generator, Quantizer


class VQVAE(nn.Module):
    """Pretrained neural codec that restores the quantizer and generator
    (and optionally the encoder) from a checkpoint and its JSON config."""

    def __init__(self, config_path, ckpt_path, with_encoder=False):
        super().__init__()
        ckpt = torch.load(ckpt_path, map_location='cpu')
        with open(config_path) as f:
            json_config = json.load(f)
        self.h = AttrDict(json_config)
        self.quantizer = Quantizer(self.h)
        self.generator = Generator(self.h)
        self.generator.load_state_dict(ckpt['generator'])
        self.quantizer.load_state_dict(ckpt['quantizer'])
        # The encoder is only needed for encode(); decoding codes back to
        # audio requires just the quantizer and generator.
        if with_encoder:
            self.encoder = Encoder(self.h)
            self.encoder.load_state_dict(ckpt['encoder'])

    def forward(self, x):
        # x holds quantizer code indices; embed them back into the latent
        # space, then decode the embeddings to a waveform.
        quant_emb = self.quantizer.embed(x)
        return self.generator(quant_emb)

    def encode(self, x):
        batch_size = x.size(0)
        # Accept (B, T, 1) input by dropping the trailing channel dim.
        if len(x.shape) == 3 and x.shape[-1] == 1:
            x = x.squeeze(-1)
        # The encoder expects (B, 1, T); the quantizer returns the quantized
        # latents, the commitment loss, and a list of per-codebook indices.
        c = self.encoder(x.unsqueeze(1))
        q, loss_q, c = self.quantizer(c)
        # Flatten each codebook's indices and stack them along the last
        # axis, giving (B, T', n_codebooks).
        c = [code.reshape(batch_size, -1) for code in c]
        return torch.stack(c, -1)
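

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal round trip through the codec. The config/checkpoint paths are
# hypothetical placeholders, and this assumes forward() accepts the
# (B, T', n_codebooks) code layout that encode() returns.
if __name__ == '__main__':
    model = VQVAE(config_path='path/to/config.json',  # hypothetical path
                  ckpt_path='path/to/checkpoint.pt',  # hypothetical path
                  with_encoder=True)
    model.eval()

    wav = torch.randn(1, 16000)      # dummy mono waveform, shape (B, T)
    with torch.no_grad():
        codes = model.encode(wav)    # -> (B, T', n_codebooks) code indices
        recon = model(codes)         # decode the codes back to audio
    print(codes.shape, recon.shape)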