|
import torch
|
|
import einops
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
from torch.optim.lr_scheduler import LambdaLR
|
|
from einops import rearrange, repeat
|
|
from functools import partial
|
|
from torchvision.utils import make_grid
|
|
from ldm.util import default, count_params, instantiate_from_config
|
|
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
|
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
from torchvision.transforms import Resize
|
|
import random
|
|
|
|
|
|
|
|
def disabled_train(self, mode=True):
    """No-op stand-in with the signature of ``nn.Module.train``.

    Ignores ``mode`` and hands back ``self`` unchanged. It is presumably meant
    to be bound over a module's ``train`` method so its train/eval state
    cannot be toggled. Confirm at the call sites.

    Args:
        self: The object to return.
        mode: Accepted for signature compatibility; ignored.

    Returns:
        ``self``, untouched.
    """
    del mode  # intentionally unused; kept only so callers can pass it
    return self
|
|
|
|
|
|
class DiffusionWrapper(pl.LightningModule):
    """Lightning module that owns the denoising network built from a config.

    The network is created with ``instantiate_from_config(unet_config)`` and
    kept as ``self.diffusion_model``. ``forward`` simply delegates to it.
    """

    def __init__(self, unet_config):
        super().__init__()
        # The attribute name becomes the state_dict prefix for the network's
        # weights, so renaming it would change checkpoint keys.
        self.diffusion_model = instantiate_from_config(unet_config)

    def forward(self, x, timesteps=None, context=None, control=None):
        """Run the wrapped network.

        All four arguments are passed positionally, in this order, to
        ``self.diffusion_model``. Their expected shapes and types are defined
        by that network (not visible here).

        Returns:
            Whatever ``self.diffusion_model`` returns.
        """
        net = self.diffusion_model
        return net(x, timesteps, context, control)
|
|
|
|
|
|
class DDPM(pl.LightningModule):
    """Diffusion base module holding the denoiser and the noise schedule.

    On construction it builds the denoising network (wrapped in
    ``DiffusionWrapper``) from ``unet_config``. It then registers the
    forward-process schedule tensors as buffers via ``register_schedule``.
    """

    def __init__(self,
                 unet_config,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 log_every_t=100,
                 timesteps=1000,
                 image_size=256,
                 channels=3,
                 u_cond_percent=0,
                 use_ema=True,
                 beta_schedule="linear",
                 loss_type="l2",
                 clip_denoised=True,
                 cosine_s=8e-3,
                 original_elbo_weight=0.,
                 v_posterior=0.,
                 l_simple_weight=1.,
                 parameterization="eps",
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.):
        """Configure the model and precompute the diffusion schedule.

        Args:
            unet_config: Config passed to ``instantiate_from_config``. It
                builds the network held by ``self.model``.
            linear_start: Start value forwarded to ``make_beta_schedule``.
            linear_end: End value forwarded to ``make_beta_schedule``.
            log_every_t: Stored as ``self.log_every_t``.
            timesteps: Requested number of diffusion steps. The effective
                ``self.num_timesteps`` is taken from the generated schedule.
            image_size: Stored as ``self.image_size``.
            channels: Stored as ``self.channels``.
            u_cond_percent: Stored as ``self.u_cond_percent``. Presumably the
                fraction of unconditional training samples; confirm at the
                use site.
            use_ema: Stored as ``self.use_ema``. No EMA copy is created in
                this constructor.
            beta_schedule: Schedule name forwarded to ``make_beta_schedule``.
            loss_type: ``"l2"`` (MSE) or ``"l1"``. Selects the loss that
                ``get_loss`` computes.
            clip_denoised: Stored as ``self.clip_denoised``.
            cosine_s: Offset forwarded to ``make_beta_schedule``.
            original_elbo_weight: Stored as ``self.original_elbo_weight``.
            v_posterior: Weight ``v`` in
                ``posterior_variance = (1 - v) * beta_tilde + v * beta``,
                where ``beta_tilde = beta * (1 - acp_prev) / (1 - acp)``.
            l_simple_weight: Stored as ``self.l_simple_weight``.
            parameterization: Stored as ``self.parameterization``.
            use_positional_encodings: Stored as
                ``self.use_positional_encodings``.
            learn_logvar: If True, ``self.logvar`` is a trainable
                ``nn.Parameter``. Otherwise it is a plain tensor.
            logvar_init: Fill value for every entry of ``self.logvar``.
        """
        super().__init__()
        self.parameterization = parameterization
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.image_size = image_size
        self.channels = channels
        self.u_cond_percent = u_cond_percent
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapper(unet_config)

        self.use_ema = use_ema
        # Hard-coded: this constructor accepts no scheduler config.
        self.use_scheduler = True
        # v_posterior must be set before register_schedule, which reads it.
        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight
        self.register_schedule(beta_schedule=beta_schedule,
                               timesteps=timesteps,
                               linear_start=linear_start,
                               linear_end=linear_end,
                               cosine_s=cosine_s)
        self.loss_type = loss_type
        self.learn_logvar = learn_logvar

        logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        # Previously `learn_logvar` was stored but never honored, so the
        # log-variance could not be made trainable.
        self.logvar = nn.Parameter(logvar, requires_grad=True) if learn_logvar else logvar

    def register_schedule(self,
                          beta_schedule="linear",
                          timesteps=1000,
                          linear_start=0.00085,
                          linear_end=0.0120,
                          cosine_s=8e-3):
        """Build the beta schedule and register the derived tensors as buffers.

        Calls ``make_beta_schedule`` and derives the alphas, their cumulative
        products, and the posterior mean/variance coefficients. Every result
        is registered as a float32 buffer. ``lvlb_weights`` is registered with
        ``persistent=False``, so it is excluded from ``state_dict``.

        ``self.num_timesteps`` is set from the length of the generated
        schedule, which overrides ``timesteps`` if the two differ. The method
        reads ``self.v_posterior``, which must already be set.
        """
        betas = make_beta_schedule(beta_schedule,
                                   timesteps,
                                   linear_start=linear_start,
                                   linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Prepend 1 so that alphas_cumprod_prev[t] == alphas_cumprod[t - 1].
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape  # the schedule is expected to be 1-D
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # Coefficients for q(x_t | x_0) and related quantities.
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # Posterior q(x_{t-1} | x_t, x_0), variance interpolated by v_posterior.
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + self.v_posterior * betas
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # Clamp before the log: posterior_variance[0] is 0 when v_posterior == 0.
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        # Epsilon-prediction VLB weighting; self.parameterization is not consulted.
        lvlb_weights = self.betas ** 2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        # With v_posterior == 0, posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1),
        # so weight 0 is a division by zero; reuse weight 1 instead.
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)

    def get_input(self, batch):
        """Extract the training tensors from ``batch``.

        Reads the keys ``'GT'``, ``'inpaint_mask'``, ``'inpaint_image'``,
        ``'ref_imgs'`` and ``'hint'``. Each value is converted to contiguous
        memory format and cast to float32.

        Returns:
            ``(x, inpaint, mask, reference, hint)``. Note that ``inpaint``
            comes before ``mask`` in the returned tuple.
        """
        def _prep(tensor):
            # Contiguous layout plus float cast, applied uniformly to every input.
            return tensor.to(memory_format=torch.contiguous_format).float()

        x = _prep(batch['GT'])
        mask = _prep(batch['inpaint_mask'])
        inpaint = _prep(batch['inpaint_image'])
        reference = _prep(batch['ref_imgs'])
        hint = _prep(batch['hint'])
        return x, inpaint, mask, reference, hint

    def q_sample(self, x_start, t, noise=None):
        """Sample from the forward process q(x_t | x_0).

        Computes ``sqrt(acp[t]) * x_start + sqrt(1 - acp[t]) * noise``, where
        ``acp`` is ``alphas_cumprod``. If ``noise`` is None, standard normal
        noise shaped like ``x_start`` is drawn. ``extract_into_tensor``
        presumably gathers the per-sample coefficient at ``t`` and reshapes it
        to broadcast against ``x_start``; confirm in its definition.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_loss(self, pred, target, mean=True):
        """Compute the reconstruction loss selected by ``self.loss_type``.

        Args:
            pred: Model prediction.
            target: Regression target.
            mean: If True, return the scalar mean. Otherwise return the
                elementwise (unreduced) loss.

        Returns:
            The L1 loss for ``loss_type == "l1"``, or MSE for ``"l2"``.

        Raises:
            NotImplementedError: If ``self.loss_type`` is neither ``"l1"``
                nor ``"l2"``. Previously such values silently fell back to MSE.
        """
        reduction = 'mean' if mean else 'none'
        if self.loss_type == 'l1':
            return torch.nn.functional.l1_loss(target, pred, reduction=reduction)
        if self.loss_type == 'l2':
            return torch.nn.functional.mse_loss(target, pred, reduction=reduction)
        raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
|
|
|
|
|