# resshift/ldm/models/autoencoder.py
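"""Inference-time autoencoder wrappers used by ResShift.

Plain ``torch.nn.Module`` ports of the latent-diffusion first-stage models
(VQ- and KL-regularized autoencoders), without any training logic.
"""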
import torch

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

class VQModelTorch(torch.nn.Module):
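    """VQGAN-style autoencoder: encoder -> 1x1 quant conv -> vector
    quantization -> 1x1 post-quant conv -> decoder.
    """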
    def __init__(self,
                 ddconfig,
                 n_embed,
                 embed_dim,
                 remap=None,
                 sane_index_shape=False,  # tell the vector quantizer to return indices shaped (b, h, w)
                 ):
super().__init__()
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap, sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h  # continuous latents; quantization happens in decode()
    def decode(self, h, force_not_quantize=False):
        # Map latents through the codebook unless they are already quantized
        # (e.g. when called from decode_code); the embedding loss and indices
        # are not needed at inference time.
        if not force_not_quantize:
            quant, _, _ = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec
    def decode_code(self, code_b):
        # Look up the codebook entries for the given indices, then decode
        # without re-quantizing.
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b, force_not_quantize=True)
        return dec
def forward(self, input, force_not_quantize=False):
h = self.encode(input)
dec = self.decode(h, force_not_quantize)
return dec
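
# Usage sketch for VQModelTorch (illustrative, not part of the original file).
# The ddconfig below is a hypothetical minimal Encoder/Decoder config; real
# values come from the model's YAML config.
#
#   ddconfig = dict(double_z=False, z_channels=3, resolution=256, in_channels=3,
#                   out_ch=3, ch=128, ch_mult=(1, 2, 4), num_res_blocks=2,
#                   attn_resolutions=[], dropout=0.0)
#   vq = VQModelTorch(ddconfig, n_embed=8192, embed_dim=3)
#   x = torch.randn(1, 3, 256, 256)
#   rec = vq(x)        # encode -> quantize -> decode
#   h = vq.encode(x)   # continuous latents; quantized inside decode()
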
class AutoencoderKLTorch(torch.nn.Module):
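    """KL-regularized autoencoder: the encoder predicts the moments of a
    diagonal Gaussian posterior; the latent z is sampled from it (or taken
    as its mode) and then decoded.
    """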
def __init__(self,
ddconfig,
embed_dim,
):
super().__init__()
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
    def encode(self, x, sample_posterior=True, return_moments=False):
        h = self.encoder(x)
        # moments = concatenated mean and log-variance of the diagonal Gaussian
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
if return_moments:
return z, moments
else:
return z
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
z = self.encode(input, sample_posterior, return_moments=False)
dec = self.decode(z)
return dec
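
# Usage sketch for AutoencoderKLTorch (illustrative). With double_z=True the
# encoder emits mean and log-variance, so quant_conv maps 2*z_channels to
# 2*embed_dim. The ddconfig values are hypothetical.
#
#   ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
#                   out_ch=3, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#                   attn_resolutions=[], dropout=0.0)
#   vae = AutoencoderKLTorch(ddconfig, embed_dim=4)
#   x = torch.randn(1, 3, 256, 256)
#   z = vae.encode(x, sample_posterior=False)  # deterministic: posterior mode
#   rec = vae.decode(z)
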
class EncoderKLTorch(torch.nn.Module):
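    """Encoder half of AutoencoderKLTorch, for pipelines that only need to
    map images to latents.
    """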
def __init__(self,
ddconfig,
embed_dim,
):
super().__init__()
self.encoder = Encoder(**ddconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.embed_dim = embed_dim
def encode(self, x, sample_posterior=True, return_moments=False):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
if return_moments:
return z, moments
else:
return z
def forward(self, x, sample_posterior=True, return_moments=False):
return self.encode(x, sample_posterior, return_moments)
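
# Usage sketch for EncoderKLTorch (illustrative): same encode() interface as
# AutoencoderKLTorch, but without decoder weights.
#
#   enc = EncoderKLTorch(ddconfig, embed_dim=4)  # hypothetical ddconfig as above
#   z, moments = enc(x, sample_posterior=True, return_moments=True)
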
class IdentityFirstStage(torch.nn.Module):
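    """No-op first stage that returns its input unchanged, optionally
    mimicking the (quant, emb_loss, info) return signature of a VQ model.
    """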
    def __init__(self, *args, vq_interface=False, **kwargs):
        super().__init__()
        self.vq_interface = vq_interface
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
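
# Usage sketch for IdentityFirstStage (illustrative): a drop-in no-op first
# stage for diffusion models that operate directly in pixel space.
#
#   first_stage = IdentityFirstStage(vq_interface=True)
#   z = first_stage.encode(x)       # returns x unchanged
#   out = first_stage.quantize(z)   # (x, None, [None, None, None])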