|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
from optvq.models.backbone.diffusion import Encoder, Decoder |
|
from optvq.models.quantizer import VectorQuantizer, VectorQuantizerSinkhorn |
|
from optvq.losses.aeloss_disc import AELossWithDisc |
|
from optvq.models.vqgan import VQModel |
|
|
|
class VQModelHF(nn.Module, PyTorchModelHubMixin):
    """``VQModel`` wrapper that also inherits ``PyTorchModelHubMixin``.

    The constructor turns plain keyword dictionaries into an ``Encoder``,
    a ``Decoder``, a quantizer and an ``AELossWithDisc``, then hands them to
    a ``VQModel`` kept as ``self.model``. ``encode`` and ``decode`` simply
    delegate to that inner model.
    """

    def __init__(self,
                 encoder: dict = {},
                 decoder: dict = {},
                 loss: dict = {},
                 quantize: dict = {},
                 quantize_type: str = "optvq",
                 ckpt_path: str = None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 use_connector: bool = True,
                 ):
        """Build the sub-modules and wrap them in a ``VQModel``.

        Args:
            encoder: Keyword arguments unpacked into ``Encoder``.
            decoder: Keyword arguments unpacked into ``Decoder``.
            loss: Keyword arguments unpacked into ``AELossWithDisc``.
            quantize: Keyword arguments unpacked into the quantizer class
                selected by ``quantize_type``.
            quantize_type: ``"optvq"`` or ``"basevq"``; see
                ``setup_quantizer``.
            ckpt_path: Forwarded unchanged to ``VQModel``.
            ignore_keys: Forwarded unchanged to ``VQModel``.
            image_key: Forwarded unchanged to ``VQModel``.
            colorize_nlabels: Forwarded unchanged to ``VQModel``.
            monitor: Forwarded unchanged to ``VQModel``.
            use_connector: Forwarded unchanged to ``VQModel``.

        Raises:
            ValueError: If ``quantize_type`` is not a recognised name.
        """
        # NOTE(review): the mutable defaults in the signature are kept as-is
        # to leave the public signature untouched. Within this constructor
        # they are only ``**``-unpacked or forwarded; whether ``VQModel``
        # mutates ``ignore_keys`` is not visible from here -- confirm.
        super().__init__()

        # Components are created in this exact order (encoder, decoder,
        # quantizer, loss) so that construction order stays unchanged.
        encoder_net = Encoder(**encoder)
        decoder_net = Decoder(**decoder)
        codebook = self.setup_quantizer(quantize, quantize_type)
        criterion = AELossWithDisc(**loss)

        self.model = VQModel(
            encoder=encoder_net,
            decoder=decoder_net,
            loss=criterion,
            quantize=codebook,
            ckpt_path=ckpt_path,
            ignore_keys=ignore_keys,
            image_key=image_key,
            colorize_nlabels=colorize_nlabels,
            monitor=monitor,
            use_connector=use_connector,
        )

    def setup_quantizer(self, quantizer_config, quantize_type):
        """Instantiate the quantizer named by ``quantize_type``.

        Args:
            quantizer_config: Keyword arguments unpacked into the chosen
                quantizer class.
            quantize_type: ``"optvq"`` selects ``VectorQuantizerSinkhorn``;
                ``"basevq"`` selects ``VectorQuantizer``.

        Returns:
            The newly constructed quantizer instance.

        Raises:
            ValueError: If ``quantize_type`` is neither of the names above.
        """
        if quantize_type == "optvq":
            return VectorQuantizerSinkhorn(**quantizer_config)
        if quantize_type == "basevq":
            return VectorQuantizer(**quantizer_config)
        raise ValueError(f"Unknown quantizer type: {quantize_type}")

    def encode(self, x):
        """Return ``self.model.encode(x)`` unchanged."""
        return self.model.encode(x)

    def decode(self, x):
        """Return ``self.model.decode(x)`` unchanged."""
        return self.model.decode(x)

    def forward(self, x):
        """Encode ``x``, then decode the first item of the encoder output.

        Returns:
            A ``(quant, rec)`` tuple: the first item yielded by
            ``self.encode(x)`` and the result of decoding that item.
        """
        # Only the first item of the encoder output is used; the remaining
        # items are discarded.
        latent, *_ = self.encode(x)
        return latent, self.decode(latent)