|
from typing import List
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from .modules.diffusionmodules.util import (
|
|
make_beta_schedule,
|
|
extract_into_tensor,
|
|
enforce_zero_terminal_snr,
|
|
noise_like,
|
|
)
|
|
from .util import exists, default, instantiate_from_config
|
|
from .modules.distributions.distributions import DiagonalGaussianDistribution
|
|
|
|
|
|
class DiffusionWrapper(nn.Module):
    """Minimal ``nn.Module`` shell that forwards every call to a wrapped model.

    The wrapped object is stored on the ``diffusion_model`` attribute. When it
    is itself an ``nn.Module``, its parameters therefore appear under the
    ``diffusion_model.`` prefix in ``state_dict`` keys.
    """

    def __init__(self, diffusion_model):
        super(DiffusionWrapper, self).__init__()
        # Attribute name is load-bearing: it determines the state_dict prefix.
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        """Call the wrapped model with exactly the arguments received."""
        inner = self.diffusion_model
        return inner(*args, **kwargs)
|
|
|
|
|
|
class LatentDiffusionInterface(nn.Module):
    """a simple interface class for LDM inference

    Bundles three sub-models built from configs with a fixed noise schedule:

    * ``self.model``     - the denoiser (built from ``unet_config``), wrapped in
      ``DiffusionWrapper`` and invoked through :meth:`apply_model`;
    * ``self.clip_model`` - the conditioning encoder used for text prompts and
      images;
    * ``self.vae_model``  - the first-stage autoencoder (``encode``/``decode``).

    The schedule coefficients are stored as float32 buffers (see
    :meth:`register_schedule`) and consumed by the q/x0/eps/v helpers below.
    """

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        """Build the sub-models and register the noise schedule.

        Args:
            unet_config, clip_config, vae_config: configs passed to
                ``instantiate_from_config`` to build the denoiser, the
                conditioning encoder and the autoencoder respectively.
            parameterization: stored on ``self.parameterization``; not read by
                this class itself (presumably consumed by samplers).
            scale_factor: latents are multiplied by this after encoding
                (:meth:`get_first_stage_encoding`) and divided by it before
                decoding (:meth:`decode_first_stage`).
            beta_schedule, timesteps, linear_start, linear_end, cosine_s,
            given_betas, zero_snr: forwarded to :meth:`register_schedule`.
            *args, **kwargs: accepted and ignored.
        """
        super().__init__()

        unet = instantiate_from_config(unet_config)
        self.model = DiffusionWrapper(unet)
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)

        self.parameterization = parameterization
        self.scale_factor = scale_factor
        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
            zero_snr=zero_snr,
        )

    @staticmethod
    def _compute_schedule_arrays(betas, v_posterior=0.0):
        """Derive every schedule buffer from a 1-D array of betas.

        Pure NumPy and independent of module state. The returned dict is in
        the order in which :meth:`register_schedule` registers the buffers.

        Args:
            betas: array-like of shape ``(T,)``; converted to float64.
            v_posterior: weight that blends the posterior variance between the
                closed-form value (0) and ``betas`` (1).

        Returns:
            dict mapping buffer name -> float64 ``np.ndarray`` of shape ``(T,)``.

        Raises:
            ValueError: if ``betas`` is not one-dimensional.
        """
        betas = np.asarray(betas, dtype=np.float64)
        if betas.ndim != 1:
            raise ValueError(f"betas must be a 1-D sequence, got shape {betas.shape}")

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        # Small epsilon keeps the reciprocals finite if alphas_cumprod hits 0
        # (presumably the final step when a zero-terminal-SNR schedule is used).
        eps = 1e-8
        recip_alphas_cumprod = 1.0 / (alphas_cumprod + eps)

        posterior_variance = (1 - v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + v_posterior * betas

        return {
            "betas": betas,
            "alphas_cumprod": alphas_cumprod,
            "alphas_cumprod_prev": alphas_cumprod_prev,
            "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
            "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - alphas_cumprod),
            "log_one_minus_alphas_cumprod": np.log(1.0 - alphas_cumprod),
            "sqrt_recip_alphas_cumprod": np.sqrt(recip_alphas_cumprod),
            "sqrt_recipm1_alphas_cumprod": np.sqrt(recip_alphas_cumprod - 1),
            "posterior_variance": posterior_variance,
            # With v_posterior == 0 the variance is exactly 0 at t=0
            # (alphas_cumprod_prev[0] == 1), so clip before taking the log.
            "posterior_log_variance_clipped": np.log(
                np.maximum(posterior_variance, 1e-20)
            ),
            "posterior_mean_coef1": betas
            * np.sqrt(alphas_cumprod_prev)
            / (1.0 - alphas_cumprod),
            "posterior_mean_coef2": (1.0 - alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - alphas_cumprod),
        }

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False,
    ):
        """Compute the diffusion schedule and register it as float32 buffers.

        Args:
            given_betas: explicit betas (array-like, e.g. a list from a config
                file). When provided, the ``beta_schedule``/``timesteps``/
                ``linear_*``/``cosine_s`` arguments are not used to build betas.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s:
                forwarded to ``make_beta_schedule`` when ``given_betas`` is absent.
            zero_snr: if true, betas are passed through
                ``enforce_zero_terminal_snr`` before anything is derived.

        Sets ``num_timesteps`` (taken from the length of the betas),
        ``linear_start``, ``linear_end`` and ``v_posterior``, and registers the
        buffers listed by :meth:`_compute_schedule_arrays`.

        Raises:
            ValueError: if the resulting betas are not one-dimensional.
        """
        if exists(given_betas):
            # Normalise lists/tuples to an ndarray; everything downstream relies
            # on ``.shape`` and ndarray arithmetic.
            betas = np.asarray(given_betas, dtype=np.float64)
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        if zero_snr:
            print("--- using zero snr---")
            betas = enforce_zero_terminal_snr(betas).numpy()

        self.v_posterior = 0
        schedule = self._compute_schedule_arrays(betas, self.v_posterior)

        # The step count comes from the betas actually used, not the argument.
        self.num_timesteps = int(schedule["betas"].shape[0])
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            schedule["alphas_cumprod"].shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)
        for buffer_name, values in schedule.items():
            self.register_buffer(buffer_name, to_torch(values))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to step ``t``.

        Returns ``sqrt(alphas_cumprod[t]) * x_start +
        sqrt(1 - alphas_cumprod[t]) * noise``. When ``noise`` is None, a fresh
        ``torch.randn_like(x_start)`` is used (via ``default``).
        ``extract_into_tensor`` presumably gathers the per-sample coefficient
        and reshapes it to broadcast against ``x_start``; confirm against the
        util module.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_v(self, x, noise, t):
        """Return the v-target ``sqrt(a_t) * noise - sqrt(1 - a_t) * x``, where ``a_t = alphas_cumprod[t]``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from ``x_t`` and predicted noise: ``sqrt(1/a_t) * x_t - sqrt(1/a_t - 1) * noise``.

        The reciprocals include the small epsilon added when the schedule was
        registered.
        """
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover x_0 from ``x_t`` and a v-prediction: ``sqrt(a_t) * x_t - sqrt(1 - a_t) * v``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Recover the noise from ``x_t`` and a v-prediction: ``sqrt(a_t) * v + sqrt(1 - a_t) * x_t``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )

    def apply_model(self, x_noisy, t, cond, **kwargs):
        """Run the denoiser on ``(x_noisy, t)``.

        The entries of ``cond`` and any extra ``kwargs`` are forwarded as
        keyword arguments.

        Raises:
            AssertionError: if ``cond`` is not a dict.
        """
        assert isinstance(cond, dict), "cond has to be a dictionary"
        return self.model(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        """Encode text prompts with ``self.clip_model``."""
        return self.clip_model(prompts)

    def get_learned_image_conditioning(self, images):
        """Encode images with ``self.clip_model.forward_image``."""
        return self.clip_model.forward_image(images)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn an encoder output into a scaled latent.

        A ``DiagonalGaussianDistribution`` is sampled; a plain tensor is used
        as-is. The result is multiplied by ``self.scale_factor``.

        Raises:
            NotImplementedError: for any other input type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def encode_first_stage(self, x):
        """Return ``self.vae_model.encode(x)`` unscaled.

        Pass the result to :meth:`get_first_stage_encoding` to obtain a scaled
        latent.
        """
        return self.vae_model.encode(x)

    def decode_first_stage(self, z):
        """Undo the latent scaling (divide by ``scale_factor``) and decode with the VAE."""
        z = 1.0 / self.scale_factor * z
        return self.vae_model.decode(z)
|
|
|