import os, math, random, argparse, logging from pathlib import Path from typing import Optional, Union, List, Callable from collections import OrderedDict from packaging import version from tqdm.auto import tqdm from omegaconf import OmegaConf import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import torchvision class PFODESolver(): def __init__(self, scheduler, t_initial=1, t_terminal=0,) -> None: self.t_initial = t_initial self.t_terminal = t_terminal self.scheduler = scheduler train_step_terminal = 0 train_step_initial = train_step_terminal + self.scheduler.config.num_train_timesteps # 0+1000 self.stepsize = (t_terminal-t_initial) / (train_step_terminal - train_step_initial) #1/1000 def get_timesteps(self, t_start, t_end, num_steps): # (b,) -> (b,1) t_start = t_start[:, None] t_end = t_end[:, None] assert t_start.dim() == 2 timepoints = torch.arange(0, num_steps, 1).expand(t_start.shape[0], num_steps).to(device=t_start.device) interval = (t_end - t_start) / (torch.ones([1], device=t_start.device) * num_steps) timepoints = t_start + interval * timepoints timesteps = (self.scheduler.num_train_timesteps - 1) + (timepoints - self.t_initial) / self.stepsize # correspondint to StableDiffusion indexing system, from 999 (t_init) -> 0 (dt) return timesteps.round().long() # return timesteps.floor().long() def solve(self, latents, unet, t_start, t_end, prompt_embeds, negative_prompt_embeds, guidance_scale=1.0, num_steps = 2, num_windows = 1, ): assert t_start.dim() == 1 assert guidance_scale >= 1 and torch.all(torch.gt(t_start, t_end)) do_classifier_free_guidance = True if guidance_scale > 1 else False bsz = latents.shape[0] if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) timestep_cond = None if unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(bsz) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=unet.config.time_cond_proj_dim ).to(device=latents.device, dtype=latents.dtype) timesteps = self.get_timesteps(t_start, t_end, num_steps).to(device=latents.device) timestep_interval = self.scheduler.config.num_train_timesteps // (num_windows * num_steps) # 7. Denoising loop with torch.no_grad(): # for i in tqdm(range(num_steps)): for i in range(num_steps): t = torch.cat([timesteps[:, i]]*2) if do_classifier_free_guidance else timesteps[:, i] # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # STEP: compute the previous noisy sample x_t -> x_t-1 # latents = self.scheduler.step(noise_pred, timesteps[:, i].cpu(), latents, return_dict=False)[0] batch_timesteps = timesteps[:, i].cpu() prev_timestep = batch_timesteps - timestep_interval # prev_timestep = batch_timesteps - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps alpha_prod_t = self.scheduler.alphas_cumprod[batch_timesteps] alpha_prod_t_prev = torch.zeros_like(alpha_prod_t) for ib in range(prev_timestep.shape[0]): alpha_prod_t_prev[ib] = self.scheduler.alphas_cumprod[prev_timestep[ib]] if prev_timestep[ib] >= 0 else self.scheduler.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t alpha_prod_t = alpha_prod_t.to(device=latents.device, dtype=latents.dtype) alpha_prod_t_prev = alpha_prod_t_prev.to(device=latents.device, dtype=latents.dtype) beta_prod_t = beta_prod_t.to(device=latents.device, dtype=latents.dtype) # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.scheduler.config.prediction_type == "epsilon": pred_original_sample = (latents - beta_prod_t[:,None,None,None] ** (0.5) * noise_pred) / alpha_prod_t[:, None,None,None] ** (0.5) pred_epsilon = noise_pred # elif self.scheduler.config.prediction_type == "sample": # pred_original_sample = noise_pred # pred_epsilon = (latents - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.scheduler.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t[:,None,None,None]**0.5) * latents - (beta_prod_t[:,None,None,None]**0.5) * noise_pred pred_epsilon = (alpha_prod_t[:,None,None,None]**0.5) * noise_pred + (beta_prod_t[:,None,None,None]**0.5) * latents else: raise ValueError( f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) pred_sample_direction = (1 - alpha_prod_t_prev[:,None,None,None]) ** (0.5) * pred_epsilon latents = alpha_prod_t_prev[:,None,None,None] ** (0.5) * pred_original_sample + pred_sample_direction return latents class PFODESolverSDXL(): def __init__(self, scheduler, t_initial=1, t_terminal=0,) -> None: self.t_initial = t_initial self.t_terminal = t_terminal self.scheduler = scheduler train_step_terminal = 0 train_step_initial = train_step_terminal + self.scheduler.config.num_train_timesteps # 0+1000 self.stepsize = (t_terminal-t_initial) / (train_step_terminal - train_step_initial) #1/1000 def get_timesteps(self, t_start, t_end, num_steps): # (b,) -> (b,1) t_start = t_start[:, None] t_end = t_end[:, None] assert t_start.dim() == 2 timepoints = torch.arange(0, num_steps, 1).expand(t_start.shape[0], num_steps).to(device=t_start.device) interval = (t_end - t_start) / (torch.ones([1], device=t_start.device) * num_steps) timepoints = t_start + interval * timepoints timesteps = (self.scheduler.num_train_timesteps - 1) + (timepoints - self.t_initial) / self.stepsize # correspondint to StableDiffusion indexing system, from 999 (t_init) -> 0 (dt) return timesteps.round().long() # return timesteps.floor().long() def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids add_time_ids = list(original_size + crops_coords_top_left + target_size) add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids def solve(self, latents, unet, t_start, t_end, prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds, guidance_scale=1.0, num_steps = 10, num_windows = 4, resolution = 1024, ): assert t_start.dim() == 1 assert guidance_scale >= 1 and torch.all(torch.gt(t_start, t_end)) dtype = latents.dtype device = latents.device bsz = latents.shape[0] do_classifier_free_guidance = True if guidance_scale > 1 else False add_text_embeds = pooled_prompt_embeds add_time_ids = torch.cat( # [self._get_add_time_ids((1024, 1024), (0, 0), (1024, 1024), dtype) for _ in range(bsz)] [self._get_add_time_ids((resolution, resolution), (0, 0), (resolution, resolution), dtype) for _ in range(bsz)] ).to(device) negative_add_time_ids = add_time_ids if do_classifier_free_guidance: # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) timestep_cond = None if unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(bsz) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=unet.config.time_cond_proj_dim ).to(device=latents.device, dtype=latents.dtype) timesteps = self.get_timesteps(t_start, t_end, num_steps).to(device=latents.device) timestep_interval = self.scheduler.config.num_train_timesteps // (num_windows * num_steps) # 7. Denoising loop with torch.no_grad(): # for i in tqdm(range(num_steps)): for i in range(num_steps): # expand the latents if we are doing classifier free guidance t = torch.cat([timesteps[:, i]]*2) if do_classifier_free_guidance else timesteps[:, i] latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # STEP: compute the previous noisy sample x_t -> x_t-1 # latents = self.scheduler.step(noise_pred, timesteps[:, i].cpu(), latents, return_dict=False)[0] batch_timesteps = timesteps[:, i].cpu() prev_timestep = batch_timesteps - timestep_interval # prev_timestep = batch_timesteps - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps alpha_prod_t = self.scheduler.alphas_cumprod[batch_timesteps] alpha_prod_t_prev = torch.zeros_like(alpha_prod_t) for ib in range(prev_timestep.shape[0]): alpha_prod_t_prev[ib] = self.scheduler.alphas_cumprod[prev_timestep[ib]] if prev_timestep[ib] >= 0 else self.scheduler.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t alpha_prod_t = alpha_prod_t.to(device=latents.device, dtype=latents.dtype) alpha_prod_t_prev = alpha_prod_t_prev.to(device=latents.device, dtype=latents.dtype) beta_prod_t = beta_prod_t.to(device=latents.device, dtype=latents.dtype) # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.scheduler.config.prediction_type == "epsilon": pred_original_sample = (latents - beta_prod_t[:,None,None,None] ** (0.5) * noise_pred) / alpha_prod_t[:, None,None,None] ** (0.5) pred_epsilon = noise_pred # elif self.scheduler.config.prediction_type == "sample": # pred_original_sample = noise_pred # pred_epsilon = (latents - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # elif self.scheduler.config.prediction_type == "v_prediction": # pred_original_sample = (alpha_prod_t**0.5) * latents - (beta_prod_t**0.5) * noise_pred # pred_epsilon = (alpha_prod_t**0.5) * noise_pred + (beta_prod_t**0.5) * latents else: raise ValueError( f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) pred_sample_direction = (1 - alpha_prod_t_prev[:,None,None,None]) ** (0.5) * pred_epsilon latents = alpha_prod_t_prev[:,None,None,None] ** (0.5) * pred_original_sample + pred_sample_direction return latents