# LayerAnimate: lvdm/models/autoencoder.py
import os
from dataclasses import dataclass
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch.utils.checkpoint import checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.utils import BaseOutput

from ..modules.ae_modules import Encoder, Decoder
from ..modules.ae_dualref_modules import VideoDecoder
from ..utils import instantiate_from_config

@dataclass
class DecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Decoded output sample of the model. Output of the last layer of the model.
    """

    sample: torch.FloatTensor

@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of a
            `DiagonalGaussianDistribution`, which allows for sampling latents
            from the distribution.
    """

    latent_dist: DiagonalGaussianDistribution
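
# A minimal usage sketch for the output types above. `vae` and `pixels` are
# hypothetical names, not defined in this file:
#
#   enc = vae.encode(pixels)          # -> AutoencoderKLOutput
#   z = enc.latent_dist.sample()      # stochastic latent (or .mode() for the mean)
#   rec = vae.decode(z)               # -> decoded pixels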

class AutoencoderKL(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self,
                 ddconfig,
                 embed_dim,
                 image_key="image",
                 input_dim=4,
                 use_checkpoint=False,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"], "ddconfig must set double_z=True so the encoder emits mean and logvar"
        # 1x1 convs that map the encoder moments (mean/logvar) into the latent
        # space and project sampled latents back to decoder channels.
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.use_checkpoint = use_checkpoint

    def encode(self, x, return_hidden_states=False, **kwargs):
        if return_hidden_states:
            # The encoder also returns its intermediate feature maps, which the
            # dual-reference decoder consumes as reference context.
            h, hidden = self.encoder(x, return_hidden_states)
            moments = self.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            return AutoencoderKLOutput(latent_dist=posterior), hidden
        else:
            h = self.encoder(x)
            moments = self.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z, **kwargs):
        if len(kwargs) == 0:
            # Plain AutoencoderKL path: undo the latent projection first.
            z = self.post_quant_conv(z)
        # With extra kwargs (e.g. ref_context) this dispatches to the SVD-style
        # video decoder, which consumes the latent directly.
        dec = self.decoder(z, **kwargs)
        return dec

    def forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        forward_temp = partial(self._forward, sample_posterior=sample_posterior, **additional_decode_kwargs)
        if self.use_checkpoint and torch.is_grad_enabled():
            # Trade compute for memory: recompute activations during backward.
            return checkpoint(forward_temp, input, use_reentrant=False)
        return forward_temp(input)

    def _forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        posterior = self.encode(input).latent_dist
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z, **additional_decode_kwargs)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if x.dim() == 5 and self.input_dim == 4:
            # Fold the temporal axis into the batch axis so the 2D
            # encoder/decoder can process videos frame by frame.
            b, c, t, h, w = x.shape
            self.b = b
            self.t = t
            x = rearrange(x, 'b c t h w -> (b t) c h w')
        return x
def get_last_layer(self):
return self.decoder.conv_out.weight
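
# The dual-reference variant below keeps the KL encoder but swaps the decoder
# for VideoDecoder, which consumes intermediate encoder features of the first
# and last frames as reference context when reconstructing a clip.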
class AutoencoderKL_Dualref(AutoencoderKL):
    @register_to_config
    def __init__(self,
                 ddconfig,
                 embed_dim,
                 image_key="image",
                 input_dim=4,
                 use_checkpoint=False,
                 ):
        super().__init__(ddconfig, embed_dim, image_key, input_dim, use_checkpoint)
        # Swap in the dual-reference video decoder; the encoder is inherited.
        self.decoder = VideoDecoder(**ddconfig)

    def _forward(self, input, batch_size, sample_posterior=True, **additional_decode_kwargs):
        posterior, hidden_states = self.encode(input, return_hidden_states=True)
        # Keep only the hidden states of the first and last frames as the two
        # reference views for the decoder.
        hidden_states_first_last = []
        for hid in hidden_states:
            hid = rearrange(hid, '(b t) c h w -> b c t h w', b=batch_size)
            hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
            hidden_states_first_last.append(hid_new)
        if sample_posterior:
            z = posterior.latent_dist.sample()
        else:
            z = posterior.latent_dist.mode()
        dec = self.decode(z, ref_context=hidden_states_first_last, **additional_decode_kwargs)
        return dec, posterior
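
if __name__ == "__main__":
    # Hedged smoke-test sketch, not part of the training pipeline. The
    # ddconfig below follows the standard LDM first-stage layout; the exact
    # values LayerAnimate uses live in its config files and may differ.
    ddconfig = dict(
        double_z=True,
        z_channels=4,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=128,
        ch_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        attn_resolutions=(),
        dropout=0.0,
    )
    vae = AutoencoderKL(ddconfig, embed_dim=4)
    x = torch.randn(2, 3, 256, 256)  # dummy image batch
    with torch.no_grad():
        rec, posterior = vae(x, sample_posterior=False)
    # Expect (2, 3, 256, 256) reconstructions and (2, 4, 32, 32) latents.
    print(rec.shape, posterior.mode().shape)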