# Source: OptVQ/optvq/models/vqgan_hf.py
# Uploaded by BorelTHU, commit 223d932 ("initiate")
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------
# Convert a Pytorch model to a Hugging Face model
from typing import Optional

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from optvq.models.backbone.diffusion import Encoder, Decoder
from optvq.models.quantizer import VectorQuantizer, VectorQuantizerSinkhorn
from optvq.losses.aeloss_disc import AELossWithDisc
from optvq.models.vqgan import VQModel
class VQModelHF(nn.Module, PyTorchModelHubMixin):
    """Hugging Face Hub wrapper around ``VQModel``.

    Builds the encoder, decoder, quantizer and loss from plain config dicts,
    wraps them in a ``VQModel`` stored as ``self.model``, and inherits from
    ``PyTorchModelHubMixin`` so the model can be saved to / loaded from the
    Hugging Face Hub.

    Args:
        encoder: Keyword arguments for ``Encoder``. ``None`` means ``{}``.
        decoder: Keyword arguments for ``Decoder``. ``None`` means ``{}``.
        loss: Keyword arguments for ``AELossWithDisc``. ``None`` means ``{}``.
        quantize: Keyword arguments for the quantizer selected by
            ``quantize_type``. ``None`` means ``{}``.
        quantize_type: ``"optvq"`` selects ``VectorQuantizerSinkhorn``;
            ``"basevq"`` selects ``VectorQuantizer``.
        ckpt_path, ignore_keys, image_key, colorize_nlabels, monitor,
        use_connector: Forwarded unchanged to ``VQModel``. An
            ``ignore_keys`` of ``None`` is replaced by a fresh empty list.

    Raises:
        ValueError: If ``quantize_type`` is not one of the values above.
    """

    def __init__(
        self,
        encoder: Optional[dict] = None,
        decoder: Optional[dict] = None,
        loss: Optional[dict] = None,
        quantize: Optional[dict] = None,
        quantize_type: str = "optvq",
        ckpt_path: Optional[str] = None,
        ignore_keys: Optional[list] = None,
        image_key: str = "image",
        colorize_nlabels=None,
        monitor=None,
        use_connector: bool = True,
    ):
        super().__init__()

        # Use None defaults instead of shared mutable ``{}`` / ``[]`` literals,
        # so separate instances never alias (or mutate) one default object.
        encoder = {} if encoder is None else encoder
        decoder = {} if decoder is None else decoder
        loss = {} if loss is None else loss
        quantize = {} if quantize is None else quantize
        ignore_keys = [] if ignore_keys is None else ignore_keys

        # Keep the construction order (encoder, decoder, quantizer, loss)
        # unchanged. Parameter initialisation may consume the global RNG,
        # so reordering could change seeded results.
        encoder_module = Encoder(**encoder)
        decoder_module = Decoder(**decoder)
        quantizer = self.setup_quantizer(quantize, quantize_type)
        loss_module = AELossWithDisc(**loss)

        self.model = VQModel(
            encoder=encoder_module,
            decoder=decoder_module,
            loss=loss_module,
            quantize=quantizer,
            ckpt_path=ckpt_path,
            ignore_keys=ignore_keys,
            image_key=image_key,
            colorize_nlabels=colorize_nlabels,
            monitor=monitor,
            use_connector=use_connector,
        )

    def setup_quantizer(self, quantizer_config, quantize_type):
        """Instantiate the vector quantizer named by ``quantize_type``.

        Args:
            quantizer_config: Keyword arguments for the quantizer constructor.
            quantize_type: ``"optvq"`` -> ``VectorQuantizerSinkhorn``;
                ``"basevq"`` -> ``VectorQuantizer``.

        Returns:
            The constructed quantizer instance.

        Raises:
            ValueError: For any other ``quantize_type``.
        """
        if quantize_type == "optvq":
            return VectorQuantizerSinkhorn(**quantizer_config)
        if quantize_type == "basevq":
            return VectorQuantizer(**quantizer_config)
        raise ValueError(f"Unknown quantizer type: {quantize_type}")

    def encode(self, x):
        """Return ``self.model.encode(x)`` unchanged.

        NOTE(review): ``forward`` treats the result as a sequence whose first
        element is the quantized latent. Confirm against ``VQModel.encode``.
        """
        return self.model.encode(x)

    def decode(self, x):
        """Return ``self.model.decode(x)`` unchanged."""
        return self.model.decode(x)

    def forward(self, x):
        """Encode ``x`` and decode the quantized latent.

        Returns:
            A ``(quant, rec)`` tuple. ``quant`` is the first element returned
            by ``encode(x)``, and ``rec`` is ``decode(quant)``. Any further
            outputs of ``encode`` are discarded.
        """
        quant, *_ = self.encode(x)
        rec = self.decode(quant)
        return quant, rec