from __future__ import annotations

import math
from contextlib import contextmanager
from typing import Any, Union

import torch
from einops import rearrange
from omegaconf import ListConfig, OmegaConf
from pytorch_lightning import LightningModule
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR

from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import default, disabled_train, get_obj_from_str, instantiate_from_config


class DiffusionEngine(LightningModule):
    def __init__(
            self,
            network_config,
            denoiser_config,
            first_stage_config,
            conditioner_config: Union[None, dict, ListConfig, OmegaConf] = None,
            sampler_config: Union[None, dict, ListConfig, OmegaConf] = None,
            optimizer_config: Union[None, dict, ListConfig, OmegaConf] = None,
            scheduler_config: Union[None, dict, ListConfig, OmegaConf] = None,
            loss_fn_config: Union[None, dict, ListConfig, OmegaConf] = None,
            network_wrapper: Union[None, str] = None,
            ckpt_path: Union[None, str] = None,
            use_ema: bool = False,
            ema_decay_rate: float = 0.9999,
            scale_factor: float = 1.0,
            disable_first_stage_autocast: bool = False,
            input_key: str = "img",
            log_keys: Union[list, None] = None,
            no_cond_log: bool = False,
            compile_model: bool = False,
            en_and_decode_n_samples_a_time: int = 1,
            num_frames: int = 25,
            slow_spatial_layers: bool = False,
            train_peft_adapters: bool = False,
            replace_cond_frames: bool = False,
            fixed_cond_frames: Union[list, None] = None
    ):
        super().__init__()
        self.log_keys = log_keys
        self.input_key = input_key
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.AdamW"}
        )
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
            model, compile_model=compile_model
        )
        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = (
            instantiate_from_config(sampler_config)
            if sampler_config is not None
            else None
        )
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )
        self.scheduler_config = scheduler_config
        self._init_first_stage(first_stage_config)
        self.loss_fn = (
            instantiate_from_config(loss_fn_config)
            if loss_fn_config is not None
            else None
        )

        # if slow_spatial_layers:
        #     for n, p in self.model.named_parameters():
        #         if "time_stack" not in n:
        #             p.requires_grad = False
        # elif train_peft_adapters:
        #     for n, p in self.model.named_parameters():
        #         if "adapter" not in n and p.requires_grad:
        #             p.requires_grad = False

        self.use_ema = use_ema
        self.ema_decay_rate = ema_decay_rate
        if use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))} buffers")
        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
        self.num_frames = num_frames
        self.slow_spatial_layers = slow_spatial_layers
        self.train_peft_adapters = train_peft_adapters
        self.replace_cond_frames = replace_cond_frames
        self.fixed_cond_frames = fixed_cond_frames

    def reinit_ema(self):
        if self.use_ema:
            self.model_ema = LitEma(self.model, decay=self.ema_decay_rate)
            print(f"Reinitializing EMAs of {len(list(self.model_ema.buffers()))} buffers")
    def init_from_ckpt(self, path: str) -> None:
        if path.endswith("ckpt"):
            svd = torch.load(path, map_location="cpu")["state_dict"]
        elif path.endswith("bin"):  # for deepspeed merged checkpoints
            svd = torch.load(path, map_location="cpu")
            for k in list(svd.keys()):  # remove the "_forward_module." prefix
                if "_forward_module" in k:
                    svd[k.replace("_forward_module.", "")] = svd[k]
                    del svd[k]
        elif path.endswith("safetensors"):
            svd = load_safetensors(path)
        else:
            raise NotImplementedError

        missing, unexpected = self.load_state_dict(svd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected keys: {unexpected}")

    def _init_first_stage(self, config):
        model = instantiate_from_config(config).eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

    def get_input(self, batch):
        # assuming unified data format, dataloader returns a dict
        # image tensors should be scaled to -1 ... 1 and in bchw format
        input_shape = batch[self.input_key].shape
        if len(input_shape) != 4:  # is an image sequence
            assert input_shape[1] == self.num_frames
            batch[self.input_key] = rearrange(batch[self.input_key], "b t c h w -> (b t) c h w")
        return batch[self.input_key]

    @torch.no_grad()
    def decode_first_stage(self, z, overlap=3):
        z = z / self.scale_factor
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
        all_out = list()
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            if overlap < n_samples:  # sliding-window decoding with overlap blending
                previous_z = z[:overlap]
                for current_z in z[overlap:].split(n_samples - overlap, dim=0):
                    if isinstance(self.first_stage_model.decoder, VideoDecoder):
                        kwargs = {"timesteps": current_z.shape[0] + overlap}
                    else:
                        kwargs = dict()
                    context_z = torch.cat((previous_z, current_z), dim=0)
                    previous_z = current_z[-overlap:]
                    out = self.first_stage_model.decode(context_z, **kwargs)

                    if not all_out:
                        all_out.append(out)
                    else:
                        # average the frames decoded by both windows to smooth the seam
                        all_out[-1][-overlap:] = (all_out[-1][-overlap:] + out[:overlap]) / 2
                        all_out.append(out[overlap:])
            else:
                for current_z in z.split(n_samples, dim=0):
                    if isinstance(self.first_stage_model.decoder, VideoDecoder):
                        kwargs = {"timesteps": current_z.shape[0]}
                    else:
                        kwargs = dict()
                    out = self.first_stage_model.decode(current_z, **kwargs)
                    all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out

    @torch.no_grad()
    def encode_first_stage(self, x):
        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = list()
        torch.cuda.synchronize()
        print(f"Encoding {n_rounds} rounds of {n_samples} samples each")
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                torch.cuda.synchronize()
                print(f"start encoding round {n}")
                out = self.first_stage_model.encode(
                    x[n * n_samples: (n + 1) * n_samples]
                )
                all_out.append(out)
                torch.cuda.synchronize()
                print(f"finished encoding round {n}")
        z = torch.cat(all_out, dim=0)
        z = z * self.scale_factor
        return z

    def forward(self, x, batch):
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)  # go to StandardDiffusionLoss
        loss_mean = loss.mean()
        loss_dict = {"loss": loss_mean}
        return loss_mean, loss_dict

    def shared_step(self, batch: dict) -> Any:
        x = self.get_input(batch)
        x = self.encode_first_stage(x)
        batch["global_step"] = self.global_step
        loss, loss_dict = self(x, batch)
        return loss, loss_dict
    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        if self.scheduler_config is not None:
            lr = self.optimizers().param_groups[0]["lr"]
            self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    # @torch.no_grad()
    # def validation_step(self, batch, batch_idx):
    #     loss, loss_dict = self.shared_step(batch)
    #     self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        _loss, loss_dict = self.shared_step(batch)
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False)

    def on_train_start(self, *args, **kwargs):
        if self.sampler is None or self.loss_fn is None:
            raise ValueError("Sampler and loss function need to be set for training")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self):
        lr = self.learning_rate
        if self.slow_spatial_layers:
            param_dicts = [
                {
                    "params": [p for n, p in self.model.named_parameters() if "time_stack" in n]
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if "time_stack" not in n],
                    "lr": lr * 0.1
                }
            ]
        elif self.train_peft_adapters:
            param_dicts = [
                {
                    "params": [p for n, p in self.model.named_parameters() if "adapter" in n]
                }
            ]
        else:
            param_dicts = [
                {
                    "params": list(self.model.parameters())
                }
            ]
        for embedder in self.conditioner.embedders:
            if embedder.is_trainable:
                param_dicts.append(
                    {
                        "params": list(embedder.parameters())
                    }
                )
        opt = self.instantiate_optimizer_from_config(param_dicts, lr, self.optimizer_config)
        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1
                }
            ]
            return [opt], scheduler
        else:
            return opt

    @torch.no_grad()
    def sample(
            self,
            cond: dict,
            cond_frame=None,
            uc: Union[dict, None] = None,
            N: int = 25,
            shape: Union[None, tuple, list] = None,
            **kwargs
    ):
        randn = torch.randn(N, *shape).to(self.device)
        cond_mask = torch.zeros(N).to(self.device)
        if self.replace_cond_frames:
            assert self.fixed_cond_frames
            cond_indices = self.fixed_cond_frames
            cond_mask = rearrange(cond_mask, "(b t) -> b t", t=self.num_frames)
            cond_mask[:, cond_indices] = 1
            cond_mask = rearrange(cond_mask, "b t -> (b t)")

        def denoiser(input, sigma, c, cond_mask):
            return self.denoiser(self.model, input, sigma, c, cond_mask, **kwargs)

        samples = self.sampler(  # go to EulerEDMSampler
            denoiser, randn, cond, uc=uc, cond_frame=cond_frame, cond_mask=cond_mask
        )
        return samples
    @torch.no_grad()
    def log_images(
            self,
            batch: dict,
            N: int = 25,
            sample: bool = True,
            ucg_keys: Union[list, None] = None,
            **kwargs
    ) -> dict:
        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders if e.ucg_rate > 0.0]
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
        else:
            ucg_keys = conditioner_input_keys

        log = dict()
        x = self.get_input(batch)
        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else list()
        )
        sampling_kwargs = dict()

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        z = self.encode_first_stage(x)
        x_reconstruct = self.decode_first_stage(z)

        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
                if c[k].shape[0] < N:
                    c[k] = c[k][[0]]
                if uc[k].shape[0] < N:
                    uc[k] = uc[k][[0]]

        if sample:
            with self.ema_scope("Plotting"):
                samples = self.sample(
                    c, cond_frame=z, shape=z.shape[1:], uc=uc, N=N, **sampling_kwargs
                )
            samples = self.decode_first_stage(samples)
            log["samples"] = log["samples_mp4"] = samples
        log["inputs"] = log["inputs_mp4"] = x
        log["targets"] = log["targets_mp4"] = x_reconstruct
        return log
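
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the engine: the key renaming performed by
# `init_from_ckpt` for deepspeed merged checkpoints. Deepspeed prefixes every
# state-dict key with "_forward_module.", which must be stripped before
# `load_state_dict` can match this module's keys. The `_demo_*` helper name
# and the toy state dict are ours, for illustration only.
def _demo_strip_deepspeed_prefix():
    svd = {"_forward_module.model.diffusion_model.weight": torch.zeros(1)}
    for k in list(svd.keys()):
        if "_forward_module" in k:
            svd[k.replace("_forward_module.", "")] = svd[k]
            del svd[k]
    assert list(svd.keys()) == ["model.diffusion_model.weight"]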
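
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the engine: the window bookkeeping behind
# `decode_first_stage`. Consecutive decoder windows share `overlap` latent
# frames, and the doubly-decoded frames are averaged to smooth the seam
# between windows. The identity `decode` below is a stand-in for
# `first_stage_model.decode`; the frame counts are toy values.
def _demo_overlap_decode():
    def decode(t):
        return t.clone()  # stand-in for the VAE decoder

    z = torch.arange(14, dtype=torch.float32).unsqueeze(1)  # 14 latent "frames"
    overlap, n_samples = 2, 6
    all_out = []
    previous_z = z[:overlap]
    for current_z in z[overlap:].split(n_samples - overlap, dim=0):
        context_z = torch.cat((previous_z, current_z), dim=0)
        previous_z = current_z[-overlap:]
        out = decode(context_z)
        if not all_out:
            all_out.append(out)
        else:
            # frames decoded by both windows are averaged
            all_out[-1][-overlap:] = (all_out[-1][-overlap:] + out[:overlap]) / 2
            all_out.append(out[overlap:])
    out = torch.cat(all_out, dim=0)
    assert torch.equal(out, z)  # every frame decoded; identity makes blending a no-op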
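
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the engine: the parameter-group pattern
# `configure_optimizers` uses when `slow_spatial_layers` is set. Parameters
# whose names contain "time_stack" (the temporal layers) train at the base
# lr; all others fall into a second group at lr * 0.1. The two-layer
# ModuleDict is a toy stand-in for the UNet.
def _demo_param_groups():
    import torch.nn as nn

    model = nn.ModuleDict({
        "time_stack": nn.Linear(4, 4),  # temporal layers: full lr
        "spatial": nn.Linear(4, 4),  # everything else: reduced lr
    })
    lr = 1e-4
    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if "time_stack" in n]},
        {"params": [p for n, p in model.named_parameters() if "time_stack" not in n], "lr": lr * 0.1},
    ]
    opt = torch.optim.AdamW(param_dicts, lr=lr)
    assert opt.param_groups[0]["lr"] == lr  # groups without "lr" inherit the base lr
    assert opt.param_groups[1]["lr"] == lr * 0.1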
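
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the engine: the conditioning-frame mask
# built by `sample` when `replace_cond_frames` is set. Frames are flattened
# as (b t), so the mask is reshaped to (b, t), the `fixed_cond_frames`
# indices are marked, and the mask is flattened back. Batch size and frame
# count are toy values.
def _demo_cond_mask():
    num_frames, batch = 5, 2
    fixed_cond_frames = [0, 1]  # condition on the first two frames of each clip
    cond_mask = torch.zeros(batch * num_frames)
    cond_mask = rearrange(cond_mask, "(b t) -> b t", t=num_frames)
    cond_mask[:, fixed_cond_frames] = 1
    cond_mask = rearrange(cond_mask, "b t -> (b t)")
    assert cond_mask.tolist() == [1, 1, 0, 0, 0, 1, 1, 0, 0, 0]


if __name__ == "__main__":
    _demo_strip_deepspeed_prefix()
    _demo_overlap_decode()
    _demo_param_groups()
    _demo_cond_mask()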