import torch
import torch.nn.functional as F
from contextlib import contextmanager

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma


class VQModelTorch(torch.nn.Module):
    def __init__(self,
                 ddconfig,
                 n_embed,
                 embed_dim,
                 remap=None,
                 sane_index_shape=False,
                 ):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap, sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b, force_not_quantize=True)
        return dec

    def forward(self, input, force_not_quantize=False):
        h = self.encode(input)
        dec = self.decode(h, force_not_quantize)
        return dec
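

# Hypothetical usage sketch (not part of the original code): shows the
# encode -> quantize -> decode round trip of VQModelTorch. The ddconfig values,
# n_embed and embed_dim below are illustrative placeholders, not taken from any
# config shipped with this repository.
def _vq_model_example():
    ddconfig = dict(
        double_z=False, z_channels=3, resolution=64, in_channels=3, out_ch=3,
        ch=32, ch_mult=[1, 2], num_res_blocks=1, attn_resolutions=[], dropout=0.0,
    )
    model = VQModelTorch(ddconfig, n_embed=512, embed_dim=3).eval()
    x = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        h = model.encode(x)       # continuous latent, shape (1, 3, 32, 32)
        rec = model.decode(h)     # quantization happens inside decode()
    return rec.shape == x.shape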


class AutoencoderKLTorch(torch.nn.Module):
    def __init__(self,
                 ddconfig,
                 embed_dim,
                 ):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

    def encode(self, x, sample_posterior=True, return_moments=False):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        if return_moments:
            return z, moments
        else:
            return z

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        z = self.encode(input, sample_posterior, return_moments=False)
        dec = self.decode(z)
        return dec
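

# Hypothetical usage sketch: reconstruct with AutoencoderKLTorch and inspect the
# posterior moments. The ddconfig and embed_dim are illustrative placeholders; the
# KL variant requires double_z=True so the encoder emits mean and log-variance.
def _kl_model_example():
    ddconfig = dict(
        double_z=True, z_channels=4, resolution=64, in_channels=3, out_ch=3,
        ch=32, ch_mult=[1, 2], num_res_blocks=1, attn_resolutions=[], dropout=0.0,
    )
    model = AutoencoderKLTorch(ddconfig, embed_dim=4).eval()
    x = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        z, moments = model.encode(x, sample_posterior=False, return_moments=True)
        rec = model.decode(z)
    # moments stack mean and log-variance along channels: 2 * embed_dim of them
    return z.shape[1] == 4 and moments.shape[1] == 8 and rec.shape == x.shape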


class EncoderKLTorch(torch.nn.Module):
    def __init__(self,
                 ddconfig,
                 embed_dim,
                 ):
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.embed_dim = embed_dim

    def encode(self, x, sample_posterior=True, return_moments=False):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        if return_moments:
            return z, moments
        else:
            return z

    def forward(self, x, sample_posterior=True, return_moments=False):
        return self.encode(x, sample_posterior, return_moments)
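

# Hypothetical usage sketch: EncoderKLTorch is the encoder half only, useful when
# the decoder is not needed (e.g. precomputing latents). Config values below are
# illustrative placeholders, not repository defaults.
def _encoder_kl_example():
    ddconfig = dict(
        double_z=True, z_channels=4, resolution=64, in_channels=3, out_ch=3,
        ch=32, ch_mult=[1, 2], num_res_blocks=1, attn_resolutions=[], dropout=0.0,
    )
    enc = EncoderKLTorch(ddconfig, embed_dim=4).eval()
    x = torch.randn(2, 3, 64, 64)
    with torch.no_grad():
        z_mode = enc(x, sample_posterior=False)   # deterministic: posterior mean
        z_samp = enc(x, sample_posterior=True)    # stochastic: one posterior sample
    return z_mode.shape == z_samp.shape == (2, 4, 32, 32)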


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        super().__init__()
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            # mimic the (quant, emb_loss, info) tuple returned by VectorQuantizer
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
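

# Hypothetical smoke check (illustrative only): IdentityFirstStage passes tensors
# through unchanged; with vq_interface=True its quantize() mimics the return
# signature of VectorQuantizer so calling code stays uniform.
if __name__ == "__main__":
    x = torch.randn(1, 3, 8, 8)
    stage = IdentityFirstStage(vq_interface=True)
    quant, loss, info = stage.quantize(x)
    assert torch.equal(quant, x) and loss is None
    assert torch.equal(stage.decode(stage.encode(x)), x)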