from __future__ import annotations

import math
from contextlib import contextmanager
from typing import Any, Union

import torch
from einops import rearrange
from omegaconf import ListConfig, OmegaConf
from pytorch_lightning import LightningModule
from safetensors.torch import load_file as load_safetensors
from torch.optim.lr_scheduler import LambdaLR

from ..modules import UNCONDITIONAL_CONFIG
from ..modules.autoencoding.temporal_ae import VideoDecoder
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ..modules.ema import LitEma
from ..util import default, disabled_train, get_obj_from_str, instantiate_from_config

class DiffusionEngine(LightningModule):
    def __init__(
        self,
        network_config,
        denoiser_config,
        first_stage_config,
        conditioner_config: Union[None, dict, ListConfig, OmegaConf] = None,
        sampler_config: Union[None, dict, ListConfig, OmegaConf] = None,
        optimizer_config: Union[None, dict, ListConfig, OmegaConf] = None,
        scheduler_config: Union[None, dict, ListConfig, OmegaConf] = None,
        loss_fn_config: Union[None, dict, ListConfig, OmegaConf] = None,
        network_wrapper: Union[None, str] = None,
        ckpt_path: Union[None, str] = None,
        use_ema: bool = False,
        ema_decay_rate: float = 0.9999,
        scale_factor: float = 1.0,
        disable_first_stage_autocast: bool = False,
        input_key: str = "img",
        log_keys: Union[list, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
        en_and_decode_n_samples_a_time: int = 1,
        num_frames: int = 25,
        slow_spatial_layers: bool = False,
        train_peft_adapters: bool = False,
        replace_cond_frames: bool = False,
        fixed_cond_frames: Union[list, None] = None
    ):
        super().__init__()
        self.log_keys = log_keys
        self.input_key = input_key
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.AdamW"}
        )
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(
            default(network_wrapper, OPENAIUNETWRAPPER)
        )(
            model, compile_model=compile_model
        )
        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = (
            instantiate_from_config(sampler_config)
            if sampler_config is not None
            else None
        )
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )
        self.scheduler_config = scheduler_config
        self._init_first_stage(first_stage_config)
        self.loss_fn = (
            instantiate_from_config(loss_fn_config)
            if loss_fn_config is not None
            else None
        )
        # if slow_spatial_layers:
        #     for n, p in self.model.named_parameters():
        #         if "time_stack" not in n:
        #             p.requires_grad = False
        # elif train_peft_adapters:
        #     for n, p in self.model.named_parameters():
        #         if "adapter" not in n and p.requires_grad:
        #             p.requires_grad = False
        self.use_ema = use_ema
        self.ema_decay_rate = ema_decay_rate
        if use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))} buffers")
        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
        self.num_frames = num_frames
        self.slow_spatial_layers = slow_spatial_layers
        self.train_peft_adapters = train_peft_adapters
        self.replace_cond_frames = replace_cond_frames
        self.fixed_cond_frames = fixed_cond_frames
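
    # Component overview (editorial comment): self.model wraps the denoising network
    # (an OpenAI-UNet-style wrapper by default), self.denoiser defines how that network
    # is called at a given noise level, self.conditioner builds the conditioning dict,
    # and self.first_stage_model is the frozen autoencoder mapping pixels to latents.
    # self.sampler and self.loss_fn may be omitted at construction time, but both must
    # be configured before training (enforced in on_train_start below).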

    def reinit_ema(self):
        if self.use_ema:
            self.model_ema = LitEma(self.model, decay=self.ema_decay_rate)
            print(f"Reinitializing EMAs of {len(list(self.model_ema.buffers()))} buffers")

    def init_from_ckpt(self, path: str) -> None:
        if path.endswith("ckpt"):
            svd = torch.load(path, map_location="cpu")["state_dict"]
        elif path.endswith("bin"):  # for deepspeed merged checkpoints
            svd = torch.load(path, map_location="cpu")
            for k in list(svd.keys()):  # remove the prefix
                if "_forward_module" in k:
                    svd[k.replace("_forward_module.", "")] = svd[k]
                    del svd[k]
        elif path.endswith("safetensors"):
            svd = load_safetensors(path)
        else:
            raise NotImplementedError(f"Unknown checkpoint format: {path}")
        missing, unexpected = self.load_state_dict(svd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected keys: {unexpected}")

    def _init_first_stage(self, config):
        model = instantiate_from_config(config).eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

    def get_input(self, batch):
        # assuming unified data format, dataloader returns a dict
        # image tensors should be scaled to -1 ... 1 and in bchw format
        input_shape = batch[self.input_key].shape
        if len(input_shape) != 4:  # is an image sequence
            assert input_shape[1] == self.num_frames
            batch[self.input_key] = rearrange(batch[self.input_key], "b t c h w -> (b t) c h w")
        return batch[self.input_key]
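
    # Shape walk-through (illustrative numbers, not from the original source): with
    # num_frames == 25, a video batch of shape (2, 25, 3, 576, 1024) is flattened to
    # (50, 3, 576, 1024), so every module downstream only sees plain image batches;
    # the frame count is recovered later via rearrange on the conditioning mask.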

    def decode_first_stage(self, z, overlap=3):
        z = z / self.scale_factor
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
        all_out = list()
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            if overlap < n_samples:
                previous_z = z[:overlap]
                for current_z in z[overlap:].split(n_samples - overlap, dim=0):
                    if isinstance(self.first_stage_model.decoder, VideoDecoder):
                        kwargs = {"timesteps": current_z.shape[0] + overlap}
                    else:
                        kwargs = dict()
                    context_z = torch.cat((previous_z, current_z), dim=0)
                    previous_z = current_z[-overlap:]
                    out = self.first_stage_model.decode(context_z, **kwargs)
                    if not all_out:
                        all_out.append(out)
                    else:
                        all_out[-1][-overlap:] = (all_out[-1][-overlap:] + out[:overlap]) / 2
                        all_out.append(out[overlap:])
            else:
                for current_z in z.split(n_samples, dim=0):
                    if isinstance(self.first_stage_model.decoder, VideoDecoder):
                        kwargs = {"timesteps": current_z.shape[0]}
                    else:
                        kwargs = dict()
                    out = self.first_stage_model.decode(current_z, **kwargs)
                    all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out
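
    # Worked example for the overlapped branch above (illustrative, assuming 25 latent
    # frames, en_and_decode_n_samples_a_time == 14 and overlap == 3): the decoder sees
    # frames 0-13 and then 11-24; frames 11-13 are decoded twice, and the two decodes
    # are averaged, which smooths the seam between successive VideoDecoder chunks.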

    def encode_first_stage(self, x):
        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = list()
        torch.cuda.synchronize()
        print(f"Encoding {n_rounds} rounds of {n_samples} samples each")
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                torch.cuda.synchronize()
                print("start encoding round", n)
                out = self.first_stage_model.encode(
                    x[n * n_samples: (n + 1) * n_samples]
                )
                all_out.append(out)
                torch.cuda.synchronize()
                print("finished encoding round", n)
        z = torch.cat(all_out, dim=0)
        z = z * self.scale_factor
        return z
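
    # Note (editorial): the torch.cuda.synchronize() calls above only make the progress
    # prints line up with when each encoding round actually finishes on the GPU; they
    # are not required for the correctness of the returned latents.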

    def forward(self, x, batch):
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)  # go to StandardDiffusionLoss
        loss_mean = loss.mean()
        loss_dict = {"loss": loss_mean}
        return loss_mean, loss_dict

    def shared_step(self, batch: dict) -> Any:
        x = self.get_input(batch)
        x = self.encode_first_stage(x)
        batch["global_step"] = self.global_step
        loss, loss_dict = self(x, batch)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        if self.scheduler_config is not None:
            lr = self.optimizers().param_groups[0]["lr"]
            self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    # @torch.no_grad()
    # def validation_step(self, batch, batch_idx):
    #     loss, loss_dict = self.shared_step(batch)
    #     self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)

    def test_step(self, batch, batch_idx):
        _loss, loss_dict = self.shared_step(batch)
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False)

    def on_train_start(self, *args, **kwargs):
        if self.sampler is None or self.loss_fn is None:
            raise ValueError("Sampler and loss function need to be set for training")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self):
        lr = self.learning_rate
        if self.slow_spatial_layers:
            param_dicts = [
                {
                    "params": [p for n, p in self.model.named_parameters() if "time_stack" in n]
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if "time_stack" not in n],
                    "lr": lr * 0.1
                }
            ]
        elif self.train_peft_adapters:
            param_dicts = [
                {
                    "params": [p for n, p in self.model.named_parameters() if "adapter" in n]
                }
            ]
        else:
            param_dicts = [
                {
                    "params": list(self.model.parameters())
                }
            ]
        for embedder in self.conditioner.embedders:
            if embedder.is_trainable:
                param_dicts.append(
                    {
                        "params": list(embedder.parameters())
                    }
                )
        opt = self.instantiate_optimizer_from_config(param_dicts, lr, self.optimizer_config)
        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1
                }
            ]
            return [opt], scheduler
        else:
            return opt
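
    # Parameter-group summary (editorial comment): slow_spatial_layers trains the
    # temporal "time_stack" modules at the full learning rate and all remaining
    # (spatial) weights at 0.1 * lr; train_peft_adapters restricts training to
    # parameters whose names contain "adapter"; otherwise every model parameter is
    # trained. Trainable conditioner embedders are always appended as extra groups.
    # An optimizer_config is a plain target/params dict, e.g. (values illustrative):
    #     {"target": "torch.optim.AdamW", "params": {"weight_decay": 0.01}}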

    def sample(
        self,
        cond: dict,
        cond_frame=None,
        uc: Union[dict, None] = None,
        N: int = 25,
        shape: Union[None, tuple, list] = None,
        **kwargs
    ):
        randn = torch.randn(N, *shape).to(self.device)
        cond_mask = torch.zeros(N).to(self.device)
        if self.replace_cond_frames:
            assert self.fixed_cond_frames
            cond_indices = self.fixed_cond_frames
            cond_mask = rearrange(cond_mask, "(b t) -> b t", t=self.num_frames)
            cond_mask[:, cond_indices] = 1
            cond_mask = rearrange(cond_mask, "b t -> (b t)")

        def denoiser(input, sigma, c, cond_mask):
            return self.denoiser(self.model, input, sigma, c, cond_mask, **kwargs)

        samples = self.sampler(  # go to EulerEDMSampler
            denoiser, randn, cond, uc=uc, cond_frame=cond_frame, cond_mask=cond_mask
        )
        return samples
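
    # Conditioning-mask example (illustrative): with num_frames == 25, N == 25 and
    # fixed_cond_frames == [0], cond_mask flags the first frame of each clip, so the
    # sampler can keep that frame pinned to cond_frame and only generate the rest.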

    def log_images(
        self,
        batch: dict,
        N: int = 25,
        sample: bool = True,
        ucg_keys: Union[list[str], None] = None,
        **kwargs
    ) -> dict:
        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders if e.ucg_rate > 0.0]
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys, "
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
        else:
            ucg_keys = conditioner_input_keys
        log = dict()
        x = self.get_input(batch)
        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys
            if len(self.conditioner.embedders) > 0
            else list()
        )
        sampling_kwargs = dict()
        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        z = self.encode_first_stage(x)
        x_reconstruct = self.decode_first_stage(z)
        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
                if c[k].shape[0] < N:
                    c[k] = c[k][[0]]
                if uc[k].shape[0] < N:
                    uc[k] = uc[k][[0]]
        if sample:
            with self.ema_scope("Plotting"):
                samples = self.sample(
                    c, cond_frame=z, shape=z.shape[1:], uc=uc, N=N, **sampling_kwargs
                )
            samples = self.decode_first_stage(samples)
            log["samples"] = log["samples_mp4"] = samples
        log["inputs"] = log["inputs_mp4"] = x
        log["targets"] = log["targets_mp4"] = x_reconstruct
        return log
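
# Minimal instantiation sketch (illustrative only; the config path and key layout are
# assumptions, not part of this module). A DiffusionEngine is normally built from an
# OmegaConf file whose sub-configs each carry a "target" and "params" entry:
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.load("configs/example.yaml")
#     engine = DiffusionEngine(**cfg.model.params)
#     engine.learning_rate = cfg.model.base_learning_rate  # read by configure_optimizers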