import random
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import einops
import pytorch_lightning as pl
from einops import rearrange, repeat
from torch.optim.lr_scheduler import LambdaLR
from torchvision.transforms import Resize
from torchvision.utils import make_grid

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import default, count_params, instantiate_from_config


def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Bound onto frozen sub-models so external ``.train()`` / ``.eval()``
    calls cannot change their mode.
    """
    return self


class DiffusionWrapper(pl.LightningModule):
    """Thin wrapper that builds the UNet denoiser from a config and forwards to it."""

    def __init__(self, unet_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(unet_config)

    def forward(self, x, timesteps=None, context=None, control=None):
        out = self.diffusion_model(x, timesteps, context, control)
        return out


class DDPM(pl.LightningModule):
    """Gaussian denoising diffusion probabilistic model (image space)."""

    def __init__(self,
                 unet_config,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 log_every_t=100,
                 timesteps=1000,
                 image_size=256,
                 channels=3,
                 u_cond_percent=0,
                 use_ema=True,
                 beta_schedule="linear",
                 loss_type="l2",
                 clip_denoised=True,
                 cosine_s=8e-3,
                 original_elbo_weight=0.,
                 v_posterior=0.,
                 l_simple_weight=1.,
                 parameterization="eps",
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.):
        """Set up the diffusion schedule, the UNet wrapper and loss options.

        Args:
            unet_config: config dict passed to ``instantiate_from_config``
                to build the UNet denoiser.
            linear_start / linear_end: endpoints of the linear beta schedule.
            timesteps: number of diffusion steps T.
            u_cond_percent: fraction of unconditional training samples
                (classifier-free guidance style) — stored, used by callers.
            v_posterior: interpolation weight between the true posterior
                variance and beta_t.
            parameterization: what the network predicts (e.g. "eps").
            learn_logvar: if True, the per-timestep log-variance becomes a
                trainable ``nn.Parameter``.
            logvar_init: initial value for every entry of ``self.logvar``.
        """
        super().__init__()
        self.parameterization = parameterization
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.image_size = image_size
        self.channels = channels
        self.u_cond_percent = u_cond_percent
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapper(unet_config)  # the UNet denoiser
        self.use_ema = use_ema
        self.use_scheduler = True
        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight
        self.register_schedule(beta_schedule=beta_schedule,
                               timesteps=timesteps,
                               linear_start=linear_start,
                               linear_end=linear_end,
                               cosine_s=cosine_s)
        self.loss_type = loss_type
        self.learn_logvar = learn_logvar
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        # BUG FIX: `learn_logvar` was previously stored but never acted on,
        # so the log-variance could never actually be optimised.
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)

    def register_schedule(self, beta_schedule="linear", timesteps=1000, linear_start=0.00085,
                          linear_end=0.0120, cosine_s=8e-3):
        """Precompute the diffusion schedule and register it as buffers.

        Buffers (not parameters) follow the module's device/dtype but are
        never optimised. Sets ``self.num_timesteps`` as a side effect.
        """
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start,
                                   linear_end=linear_end, cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # \bar{alpha}_{t-1}, with \bar{alpha}_{-1} := 1.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # Variance of q(x_{t-1} | x_t, x_0), interpolated toward beta_t
        # by v_posterior.
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # The posterior variance is 0 at t=0, so clamp before taking the log.
        self.register_buffer('posterior_log_variance_clipped',
                             to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1',
                             to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2',
                             to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        # ELBO weights (eps parameterization); the t=0 entry is ill-defined
        # (division by zero), so copy it from t=1.
        lvlb_weights = self.betas ** 2 / (
                2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)

    def get_input(self, batch):
        """Extract training tensors from a dataloader batch.

        Expects keys 'GT', 'inpaint_mask', 'inpaint_image', 'ref_imgs' and
        'hint'; each is made contiguous and cast to float.

        Returns:
            Tuple ``(x, inpaint, mask, reference, hint)``.
        """
        x = batch['GT']
        mask = batch['inpaint_mask']
        inpaint = batch['inpaint_image']
        reference = batch['ref_imgs']
        hint = batch['hint']

        x = x.to(memory_format=torch.contiguous_format).float()
        mask = mask.to(memory_format=torch.contiguous_format).float()
        inpaint = inpaint.to(memory_format=torch.contiguous_format).float()
        reference = reference.to(memory_format=torch.contiguous_format).float()
        hint = hint.to(memory_format=torch.contiguous_format).float()
        return x, inpaint, mask, reference, hint

    def q_sample(self, x_start, t, noise=None):
        """Forward-diffuse x_start to step t:
        sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_loss(self, pred, target, mean=True):
        """Reconstruction loss between prediction and target.

        BUG FIX: the constructor accepts ``loss_type`` ("l1"/"l2") but the
        old implementation always used MSE, silently ignoring "l1". The
        default "l2" path is byte-for-byte the same computation as before.

        Args:
            pred: network output.
            target: regression target (noise or x0 depending on
                parameterization).
            mean: if True return the scalar mean, otherwise the
                per-element loss tensor.
        """
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
        return loss