|
from typing import Tuple, Set, List, Dict |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from model import ( |
|
ControlledUnetModel, ControlNet, |
|
AutoencoderKL, FrozenOpenCLIPEmbedder |
|
) |
|
from utils.common import sliding_windows, count_vram_usage, gaussian_weights |
|
|
|
|
|
def disabled_train(self: nn.Module) -> nn.Module:
    """No-op replacement for ``nn.Module.train``.

    Assigning this function as a module's ``train`` attribute freezes the
    module's train/eval mode: subsequent ``.train()`` / ``.eval()`` calls
    become no-ops that simply return the module unchanged.
    """
    return self
|
|
|
|
|
class ControlLDM(nn.Module):
    """Stable-Diffusion-style latent diffusion model steered by a ControlNet.

    Bundles the denoising UNet, the VAE, the OpenCLIP text encoder and the
    ControlNet, and provides helpers for loading pretrained weights and for
    plain or tiled VAE encoding/decoding.
    """

    def __init__(
        self,
        unet_cfg,
        vae_cfg,
        clip_cfg,
        controlnet_cfg,
        latent_scale_factor,
    ):
        """Instantiate all sub-modules from their config mappings.

        Args:
            unet_cfg: kwargs for ``ControlledUnetModel``.
            vae_cfg: kwargs for ``AutoencoderKL``.
            clip_cfg: kwargs for ``FrozenOpenCLIPEmbedder``.
            controlnet_cfg: kwargs for ``ControlNet``.
            latent_scale_factor: multiplier applied to VAE latents
                (the Stable Diffusion convention; presumably 0.18215 —
                confirm against the caller's config).
        """
        super().__init__()
        self.unet = ControlledUnetModel(**unet_cfg)
        self.vae = AutoencoderKL(**vae_cfg)
        self.clip = FrozenOpenCLIPEmbedder(**clip_cfg)
        self.controlnet = ControlNet(**controlnet_cfg)
        self.scale_factor = latent_scale_factor
        # One multiplier per ControlNet residual handed to the UNet
        # (13 entries — presumably 12 block outputs plus the middle
        # block for the SD UNet; confirm against ControlNet's outputs).
        self.control_scales = [1.0] * 13

    @torch.no_grad()
    def load_pretrained_sd(self, sd: Dict[str, torch.Tensor]) -> Set[str]:
        """Load UNet/VAE/CLIP weights from a full Stable Diffusion checkpoint.

        The checkpoint uses SD's module prefixes (``model.diffusion_model``,
        ``first_stage_model``, ``cond_stage_model``); this remaps them onto
        our attribute names. The three loaded modules are then frozen: put
        in eval mode, their ``train`` method disabled, and their parameters
        excluded from gradient computation.

        Args:
            sd: flat state dict of a full SD checkpoint.

        Returns:
            The set of checkpoint keys that were NOT consumed (e.g. EMA
            weights or scheduler buffers present in the checkpoint).

        Raises:
            KeyError: if an expected prefixed key is missing from ``sd``.
        """
        module_map = {
            "unet": "model.diffusion_model",
            "vae": "first_stage_model",
            "clip": "cond_stage_model",
        }
        modules = [("unet", self.unet), ("vae", self.vae), ("clip", self.clip)]
        used = set()
        for name, module in modules:
            init_sd = {}
            scratch_sd = module.state_dict()
            for key in scratch_sd:
                target_key = ".".join([module_map[name], key])
                init_sd[key] = sd[target_key].clone()
                used.add(target_key)
            module.load_state_dict(init_sd, strict=True)
        unused = set(sd.keys()) - used
        # Freeze the pretrained modules. Overwriting the bound `train`
        # method with `disabled_train` is a deliberate hack: recursive
        # `.train(mode)` calls from a parent module become no-ops, so
        # these stay in eval mode permanently.
        for module in [self.vae, self.clip, self.unet]:
            module.eval()
            module.train = disabled_train
            for p in module.parameters():
                p.requires_grad = False
        return unused

    @torch.no_grad()
    def load_controlnet_from_ckpt(self, sd: Dict[str, torch.Tensor]) -> None:
        """Load ControlNet weights from a checkpoint state dict (strict)."""
        self.controlnet.load_state_dict(sd, strict=True)

    @torch.no_grad()
    def load_controlnet_from_unet(self) -> Tuple[Set[str], Set[str]]:
        """Initialize the ControlNet from the (already loaded) UNet weights.

        Keys shared with the UNet are copied. A conv weight whose input-
        channel count is larger than the UNet's (the hint/conditioning
        branch) is zero-padded along dim 1, so the extra condition channels
        start with no influence. Keys absent from the UNet keep their
        scratch initialization.

        Returns:
            ``(init_with_new_zero, init_with_scratch)``: the keys that were
            zero-padded and the keys left at scratch init, respectively.
        """
        unet_sd = self.unet.state_dict()
        scratch_sd = self.controlnet.state_dict()
        init_sd = {}
        init_with_new_zero = set()
        init_with_scratch = set()
        for key in scratch_sd:
            if key in unet_sd:
                this, target = scratch_sd[key], unet_sd[key]
                if this.size() == target.size():
                    init_sd[key] = target.clone()
                else:
                    # Shape mismatch: ControlNet conv expects extra input
                    # channels. Pad with zeros so the new channels are
                    # initially inert. Match device as well as dtype so
                    # torch.cat works when the UNet weights live on GPU.
                    d_ic = this.size(1) - target.size(1)
                    oc, _, h, w = this.size()
                    zeros = torch.zeros(
                        (oc, d_ic, h, w), dtype=target.dtype, device=target.device
                    )
                    init_sd[key] = torch.cat((target, zeros), dim=1)
                    init_with_new_zero.add(key)
            else:
                init_sd[key] = scratch_sd[key].clone()
                init_with_scratch.add(key)
        self.controlnet.load_state_dict(init_sd, strict=True)
        return init_with_new_zero, init_with_scratch

    def vae_encode(self, image: torch.Tensor, sample: bool = True) -> torch.Tensor:
        """Encode an image to a scaled latent.

        Args:
            image: input batch (expected in the VAE's input range,
                presumably [-1, 1] — see ``prepare_condition``).
            sample: if True, sample from the posterior; otherwise use its
                mode (deterministic).

        Returns:
            Latent tensor multiplied by ``self.scale_factor``.
        """
        if sample:
            return self.vae.encode(image).sample() * self.scale_factor
        else:
            return self.vae.encode(image).mode() * self.scale_factor

    def vae_encode_tiled(self, image: torch.Tensor, tile_size: int, tile_stride: int, sample: bool = True) -> torch.Tensor:
        """Encode a large image tile-by-tile to bound VRAM usage.

        Tiles (in pixel space, ``tile_size``/``tile_stride`` pixels) are
        encoded independently and blended in latent space with Gaussian
        weights to hide seams. The factor 8 is the VAE's spatial
        downsampling ratio.

        Args:
            image: (B, C, H, W) input batch; H and W are assumed to be
                multiples of 8 — TODO confirm callers guarantee this.
            tile_size: tile edge length in pixels.
            tile_stride: stride between tiles in pixels.
            sample: forwarded to ``vae_encode``.

        Returns:
            (B, 4, H/8, W/8) blended latent tensor.
        """
        bs, _, h, w = image.shape
        z = torch.zeros((bs, 4, h // 8, w // 8), dtype=torch.float32, device=image.device)
        count = torch.zeros_like(z, dtype=torch.float32)
        weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
        weights = torch.tensor(weights, dtype=torch.float32, device=image.device)
        tiles = sliding_windows(h // 8, w // 8, tile_size // 8, tile_stride // 8)
        for hi, hi_end, wi, wi_end in tiles:
            tile_image = image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8]
            z[:, :, hi:hi_end, wi:wi_end] += self.vae_encode(tile_image, sample=sample) * weights
            count[:, :, hi:hi_end, wi:wi_end] += weights
        # Normalize by the accumulated weights of all overlapping tiles.
        z.div_(count)
        return z

    def vae_decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a scaled latent back to image space (undoes scale_factor)."""
        return self.vae.decode(z / self.scale_factor)

    @count_vram_usage
    def vae_decode_tiled(self, z: torch.Tensor, tile_size: int, tile_stride: int) -> torch.Tensor:
        """Decode a large latent tile-by-tile to bound VRAM usage.

        Mirror of ``vae_encode_tiled``: tiles are specified in latent
        units here, decoded independently, and blended in pixel space with
        Gaussian weights (factor 8 = VAE upsampling ratio).

        Args:
            z: (B, 4, h, w) scaled latent.
            tile_size: tile edge length in latent units.
            tile_stride: stride between tiles in latent units.

        Returns:
            (B, 3, h*8, w*8) blended image tensor.
        """
        bs, _, h, w = z.shape
        image = torch.zeros((bs, 3, h * 8, w * 8), dtype=torch.float32, device=z.device)
        count = torch.zeros_like(image, dtype=torch.float32)
        weights = gaussian_weights(tile_size * 8, tile_size * 8)[None, None]
        weights = torch.tensor(weights, dtype=torch.float32, device=z.device)
        tiles = sliding_windows(h, w, tile_size, tile_stride)
        for hi, hi_end, wi, wi_end in tiles:
            tile_z = z[:, :, hi:hi_end, wi:wi_end]
            image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += self.vae_decode(tile_z) * weights
            count[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += weights
        # Normalize by the accumulated weights of all overlapping tiles.
        image.div_(count)
        return image

    def prepare_condition(self, clean: torch.Tensor, txt: List[str]) -> Dict[str, torch.Tensor]:
        """Build the conditioning dict from a clean image and text prompts.

        ``clean`` is mapped from [0, 1] to [-1, 1] before VAE encoding;
        the latent uses the posterior mode (deterministic).

        Returns:
            dict with ``c_txt`` (CLIP embeddings) and ``c_img`` (latent).
        """
        return dict(
            c_txt=self.clip.encode(txt),
            c_img=self.vae_encode(clean * 2 - 1, sample=False)
        )

    @count_vram_usage
    def prepare_condition_tiled(self, clean: torch.Tensor, txt: List[str], tile_size: int, tile_stride: int) -> Dict[str, torch.Tensor]:
        """Tiled variant of ``prepare_condition`` for large inputs."""
        return dict(
            c_txt=self.clip.encode(txt),
            c_img=self.vae_encode_tiled(clean * 2 - 1, tile_size, tile_stride, sample=False)
        )

    def forward(self, x_noisy, t, cond):
        """Predict noise for ``x_noisy`` at timestep ``t`` under ``cond``.

        Args:
            x_noisy: noisy latent batch.
            t: diffusion timesteps.
            cond: dict with ``c_txt`` and ``c_img`` (see prepare_condition).

        Returns:
            The UNet's noise prediction (eps).
        """
        c_txt = cond["c_txt"]
        c_img = cond["c_img"]
        control = self.controlnet(
            x=x_noisy, hint=c_img,
            timesteps=t, context=c_txt
        )
        # Per-output scaling of the ControlNet residuals.
        control = [c * scale for c, scale in zip(control, self.control_scales)]
        eps = self.unet(
            x=x_noisy, timesteps=t,
            context=c_txt, control=control, only_mid_control=False
        )
        return eps
|
|