# Spaces: Running on Zero
import torch | |
import einops | |
import torch.nn as nn | |
import numpy as np | |
import pytorch_lightning as pl | |
from torch.optim.lr_scheduler import LambdaLR | |
from einops import rearrange, repeat | |
from functools import partial | |
from torchvision.utils import make_grid | |
from ldm.util import default, count_params, instantiate_from_config | |
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution | |
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from torchvision.transforms import Resize | |
import random | |
def disabled_train(self, mode=True):
    """Return ``self`` unchanged; ``mode`` is accepted but ignored.

    The signature matches ``torch.nn.Module.train(mode)``.

    NOTE(review): presumably meant to be assigned over a module's ``train``
    method so its train/eval state can no longer be toggled -- no such use
    is visible in this part of the file; confirm against callers.
    """
    return self
class DiffusionWrapper(pl.LightningModule):
    """Lightning module that owns the network built from ``unet_config``."""

    def __init__(self, unet_config):
        """Build the wrapped network.

        Args:
            unet_config: Config handed to ``instantiate_from_config``; the
                object it returns is stored as ``self.diffusion_model``.
        """
        super().__init__()
        network = instantiate_from_config(unet_config)
        self.diffusion_model = network

    def forward(self, x, timesteps=None, context=None, control=None):
        """Delegate to ``self.diffusion_model`` and return its output.

        The four arguments are forwarded positionally, in the order
        ``(x, timesteps, context, control)``.
        """
        return self.diffusion_model(x, timesteps, context, control)
class DDPM(pl.LightningModule): | |
def __init__(self,
             unet_config,
             linear_start=1e-4,
             linear_end=2e-2,
             log_every_t=100,
             timesteps=1000,
             image_size=256,
             channels=3,
             u_cond_percent=0,
             use_ema=True,
             beta_schedule="linear",
             loss_type="l2",
             clip_denoised=True,
             cosine_s=8e-3,
             original_elbo_weight=0.,
             v_posterior=0.,
             l_simple_weight=1.,
             parameterization="eps",
             use_positional_encodings=False,
             learn_logvar=False,
             logvar_init=0.):
    """Store the hyper-parameters, build the network and the noise schedule.

    Args:
        unet_config: Config passed to ``DiffusionWrapper`` to build
            ``self.model``.
        linear_start, linear_end, cosine_s, beta_schedule, timesteps:
            Forwarded to ``register_schedule``.
        v_posterior: Read by ``register_schedule`` when it computes the
            posterior variance, so it is assigned before that call.
        logvar_init: Fill value of ``self.logvar``, a tensor with one entry
            per timestep.
        Remaining arguments are stored verbatim as attributes of the same
        name.

    NOTE(review): the original source annotated several defaults with
    alternative values (linear_start 0.00085, linear_end 0.0120,
    log_every_t 200, image_size 32, channels 4, u_cond_percent 0.2,
    use_ema False); these look like the values used by the project's
    config -- confirm there.
    NOTE(review): ``learn_logvar`` and ``use_ema`` are only stored here;
    ``self.logvar`` stays a plain tensor (neither parameter nor buffer).
    """
    super().__init__()

    # Plain configuration attributes.
    self.parameterization = parameterization
    self.clip_denoised = clip_denoised
    self.log_every_t = log_every_t
    self.image_size = image_size
    self.channels = channels
    self.u_cond_percent = u_cond_percent
    self.use_positional_encodings = use_positional_encodings
    self.use_ema = use_ema
    self.use_scheduler = True
    self.loss_type = loss_type
    self.learn_logvar = learn_logvar

    # Loss weighting terms; v_posterior must exist before the schedule is
    # registered because register_schedule reads it.
    self.v_posterior = v_posterior
    self.original_elbo_weight = original_elbo_weight
    self.l_simple_weight = l_simple_weight

    # Sub-models: the UNet wrapper, and an empty conditioning-stage slot.
    self.cond_stage_model = None
    self.model = DiffusionWrapper(unet_config)

    self.register_schedule(
        beta_schedule=beta_schedule,
        timesteps=timesteps,
        linear_start=linear_start,
        linear_end=linear_end,
        cosine_s=cosine_s,
    )

    # num_timesteps is set by register_schedule above.
    self.logvar = torch.full((self.num_timesteps,), logvar_init)
def register_schedule(self,
                      beta_schedule="linear",
                      timesteps=1000,
                      linear_start=0.00085,
                      linear_end=0.0120,
                      cosine_s=8e-3):
    """Compute the beta schedule and register every derived table as a buffer.

    The betas come from ``make_beta_schedule``; its result must be a 1-D
    array, since its shape is unpacked into a single length. All derived
    quantities are computed with NumPy and stored as float32 buffers.

    Side effects: sets ``num_timesteps``, ``linear_start`` and
    ``linear_end``; registers the schedule buffers listed below plus the
    non-persistent ``lvlb_weights``. Reads ``self.v_posterior``.
    """
    betas = make_beta_schedule(beta_schedule,
                               timesteps,
                               linear_start=linear_start,
                               linear_end=linear_end,
                               cosine_s=cosine_s)
    alphas = 1. - betas
    cumprod = np.cumprod(alphas, axis=0)
    # Cumulative product shifted by one step, with 1.0 for the first step.
    cumprod_prev = np.append(1., cumprod[:-1])

    (schedule_len,) = betas.shape
    self.num_timesteps = int(schedule_len)
    self.linear_start = linear_start
    self.linear_end = linear_end

    as_float32 = partial(torch.tensor, dtype=torch.float32)

    # Blend of the two posterior-variance choices, weighted by v_posterior.
    post_var = ((1 - self.v_posterior) * betas * (1. - cumprod_prev) / (1. - cumprod)
                + self.v_posterior * betas)

    # Registration order is kept identical to the original implementation.
    tables = (
        ('betas', betas),
        ('alphas_cumprod', cumprod),
        ('alphas_cumprod_prev', cumprod_prev),
        ('sqrt_alphas_cumprod', np.sqrt(cumprod)),
        ('sqrt_one_minus_alphas_cumprod', np.sqrt(1. - cumprod)),
        ('log_one_minus_alphas_cumprod', np.log(1. - cumprod)),
        ('sqrt_recip_alphas_cumprod', np.sqrt(1. / cumprod)),
        ('sqrt_recipm1_alphas_cumprod', np.sqrt(1. / cumprod - 1)),
        ('posterior_variance', post_var),
        # Clamp before the log so a zero variance does not produce -inf.
        ('posterior_log_variance_clipped', np.log(np.maximum(post_var, 1e-20))),
        ('posterior_mean_coef1', betas * np.sqrt(cumprod_prev) / (1. - cumprod)),
        ('posterior_mean_coef2', (1. - cumprod_prev) * np.sqrt(alphas) / (1. - cumprod)),
    )
    for buffer_name, values in tables:
        self.register_buffer(buffer_name, as_float32(values))

    denominator = 2 * self.posterior_variance * as_float32(alphas) * (1 - self.alphas_cumprod)
    lvlb_weights = self.betas ** 2 / denominator
    # cumprod_prev[0] is 1, so with v_posterior == 0 the first posterior
    # variance is 0 and the first weight is a division by zero; reuse the
    # second entry instead.
    lvlb_weights[0] = lvlb_weights[1]
    self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
def get_input(self, batch):
    """Pull the five training tensors out of ``batch`` as contiguous floats.

    Args:
        batch: Mapping with the keys ``'GT'``, ``'inpaint_mask'``,
            ``'inpaint_image'``, ``'ref_imgs'`` and ``'hint'``.

    Returns:
        Tuple ``(x, inpaint, mask, reference, hint)`` built from
        ``'GT'``, ``'inpaint_image'``, ``'inpaint_mask'``, ``'ref_imgs'``
        and ``'hint'`` respectively -- note that the inpaint image comes
        before the mask in the result.
    """
    # Look everything up first, then convert, so a missing key is reported
    # before any tensor work happens.
    lookup_order = ('GT', 'inpaint_mask', 'inpaint_image', 'ref_imgs', 'hint')
    raw_entries = [batch[key] for key in lookup_order]

    x, mask, inpaint, reference, hint = (
        entry.to(memory_format=torch.contiguous_format).float()
        for entry in raw_entries
    )
    return x, inpaint, mask, reference, hint
def q_sample(self, x_start, t, noise=None):
    """Return the noised version of ``x_start`` at timestep(s) ``t``.

    The result is ``sqrt_alphas_cumprod[t] * x_start +
    sqrt_one_minus_alphas_cumprod[t] * noise``, with both coefficients
    gathered through ``extract_into_tensor`` for ``x_start``'s shape.

    Args:
        x_start: Tensor to be noised.
        t: Timestep indices handed to ``extract_into_tensor``.
        noise: Optional noise tensor; ``default`` substitutes
            ``torch.randn_like(x_start)`` when it is not supplied.
    """
    noise = default(noise, lambda: torch.randn_like(x_start))
    target_shape = x_start.shape
    signal_coef = extract_into_tensor(self.sqrt_alphas_cumprod, t, target_shape)
    noise_coef = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, target_shape)
    return signal_coef * x_start + noise_coef * noise
def get_loss(self, pred, target, mean=True):
    """Compute the loss between ``pred`` and ``target`` per ``self.loss_type``.

    The constructor stores ``loss_type`` but this method previously ignored
    it and always used MSE; it is now honoured.

    Args:
        pred: Predicted tensor.
        target: Reference tensor, same shape as ``pred``.
        mean: When true, reduce to a scalar mean; otherwise return the
            element-wise loss with the shape of the inputs.

    Returns:
        ``'l2'`` (also used when ``self.loss_type`` is not set): mean
        squared error, identical to the previous behaviour.
        ``'l1'``: absolute error.

    Raises:
        NotImplementedError: ``self.loss_type`` is neither ``'l1'`` nor
            ``'l2'``.
    """
    # Fall back to 'l2' so the method still works on objects whose
    # constructor has not set loss_type.
    loss_type = getattr(self, 'loss_type', 'l2')
    reduction = 'mean' if mean else 'none'

    if loss_type == 'l2':
        return torch.nn.functional.mse_loss(target, pred, reduction=reduction)
    if loss_type == 'l1':
        return torch.nn.functional.l1_loss(target, pred, reduction=reduction)
    # Fail loudly rather than silently training with a different loss.
    raise NotImplementedError(f"unknown loss type '{loss_type}'")