|
from typing import Optional, Tuple, Dict |
|
|
|
import torch |
|
from torch import nn |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
from model.gaussian_diffusion import extract_into_tensor |
|
from model.cldm import ControlLDM |
|
from utils.cond_fn import Guidance |
|
from utils.common import sliding_windows, gaussian_weights |
|
|
|
|
|
|
|
def space_timesteps(num_timesteps, section_counts): |
|
""" |
|
Create a list of timesteps to use from an original diffusion process, |
|
given the number of timesteps we want to take from equally-sized portions |
|
of the original process. |
|
    For example, if there are 300 timesteps and the section counts are [10,15,20],
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
If the stride is a string starting with "ddim", then the fixed striding |
|
from the DDIM paper is used, and only one section is allowed. |
|
:param num_timesteps: the number of diffusion steps in the original |
|
process to divide up. |
|
:param section_counts: either a list of numbers, or a string containing |
|
comma-separated numbers, indicating the step count |
|
per section. As a special case, use "ddimN" where N |
|
is a number of steps to use the striding from the |
|
DDIM paper. |
|
:return: a set of diffusion steps from the original process to use. |
|
""" |
|
if isinstance(section_counts, str): |
|
if section_counts.startswith("ddim"): |
|
desired_count = int(section_counts[len("ddim") :]) |
|
for i in range(1, num_timesteps): |
|
if len(range(0, num_timesteps, i)) == desired_count: |
|
return set(range(0, num_timesteps, i)) |
|
raise ValueError( |
|
f"cannot create exactly {num_timesteps} steps with an integer stride" |
|
) |
|
section_counts = [int(x) for x in section_counts.split(",")] |
|
size_per = num_timesteps // len(section_counts) |
|
extra = num_timesteps % len(section_counts) |
|
start_idx = 0 |
|
all_steps = [] |
|
for i, section_count in enumerate(section_counts): |
|
size = size_per + (1 if i < extra else 0) |
|
if size < section_count: |
|
raise ValueError( |
|
f"cannot divide section of {size} steps into {section_count}" |
|
) |
|
if section_count <= 1: |
|
frac_stride = 1 |
|
else: |
|
frac_stride = (size - 1) / (section_count - 1) |
|
cur_idx = 0.0 |
|
taken_steps = [] |
|
for _ in range(section_count): |
|
taken_steps.append(start_idx + round(cur_idx)) |
|
cur_idx += frac_stride |
|
all_steps += taken_steps |
|
start_idx += size |
|
return set(all_steps) |
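

# A quick sanity check of the spacing logic (illustrative only): with 1000
# original steps, an integer count picks evenly spaced indices via fractional
# striding, while the "ddimN" form searches for an exact integer stride.
#
#   >>> len(space_timesteps(1000, "50"))
#   50
#   >>> sorted(space_timesteps(1000, "ddim50"))[:3]  # stride of 20
#   [0, 20, 40]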
|
|
|
|
|
class SpacedSampler(nn.Module): |
|
""" |
|
    Implementation of the spaced sampling schedule proposed in IDDPM. This class
    is designed for sampling from ControlLDM.
|
|
|
https://arxiv.org/pdf/2102.09672.pdf |
|
""" |
|
|
|
    def __init__(self, betas: np.ndarray) -> None:
|
super().__init__() |
|
self.num_timesteps = len(betas) |
|
self.original_betas = betas |
|
self.original_alphas_cumprod = np.cumprod(1.0 - betas, axis=0) |
|
self.context = {} |
|
|
|
def register(self, name: str, value: np.ndarray) -> None: |
|
self.register_buffer(name, torch.tensor(value, dtype=torch.float32)) |
|
|
|
    def make_schedule(self, num_steps: int) -> None:
        # Build a spaced schedule: keep `num_steps` of the original timesteps
        # and recompute betas so that the marginals at the kept timesteps match
        # the original process (cf. improved-diffusion's respace.py).
        used_timesteps = space_timesteps(self.num_timesteps, str(num_steps))
|
betas = [] |
|
last_alpha_cumprod = 1.0 |
|
        for i, alpha_cumprod in enumerate(self.original_alphas_cumprod):
            if i in used_timesteps:
                # Choose beta so that the cumulative product of alphas over the
                # kept timesteps reproduces the original alphas_cumprod.
                betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
|
assert len(betas) == num_steps |
|
self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32) |
|
|
|
betas = np.array(betas, dtype=np.float64) |
|
alphas = 1.0 - betas |
|
alphas_cumprod = np.cumprod(alphas, axis=0) |
|
|
|
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) |
|
sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod) |
|
sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1) |
|
|
|
posterior_variance = ( |
|
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) |
|
) |
|
|
|
|
|
        # The log computation is clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain.
        posterior_log_variance_clipped = np.log(
            np.append(posterior_variance[1], posterior_variance[1:])
        )
|
posterior_mean_coef1 = ( |
|
betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) |
|
) |
|
posterior_mean_coef2 = ( |
|
(1.0 - alphas_cumprod_prev) |
|
* np.sqrt(alphas) |
|
/ (1.0 - alphas_cumprod) |
|
) |
|
|
|
self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod) |
|
self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod) |
|
self.register("posterior_variance", posterior_variance) |
|
self.register("posterior_log_variance_clipped", posterior_log_variance_clipped) |
|
self.register("posterior_mean_coef1", posterior_mean_coef1) |
|
self.register("posterior_mean_coef2", posterior_mean_coef2) |
|
|
|
    def q_posterior_mean_variance(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the mean and variance of the posterior distribution q(x_{t-1} | x_t, x_0).
|
|
|
Args: |
|
x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`. |
|
x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`. |
|
t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get |
|
parameters for each timestep. |
|
|
|
Returns: |
|
posterior_mean (torch.Tensor): Mean of the posterior distribution. |
|
posterior_variance (torch.Tensor): Variance of the posterior distribution. |
|
posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution. |
|
""" |
|
posterior_mean = ( |
|
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start |
|
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t |
|
) |
|
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) |
|
posterior_log_variance_clipped = extract_into_tensor( |
|
self.posterior_log_variance_clipped, t, x_t.shape |
|
) |
|
return posterior_mean, posterior_variance, posterior_log_variance_clipped |
|
|
|
def _predict_xstart_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor: |
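        # Invert the forward process x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps
        # (with a_t = alphas_cumprod[t]), i.e.
        # x_0 = x_t / sqrt(a_t) - sqrt(1 / a_t - 1) * eps.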
|
return ( |
|
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t |
|
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps |
|
) |
|
|
|
def apply_cond_fn( |
|
self, |
|
model: ControlLDM, |
|
pred_x0: torch.Tensor, |
|
t: torch.Tensor, |
|
index: torch.Tensor, |
|
cond_fn: Guidance |
|
) -> torch.Tensor: |
|
        t_now = int(t[0].item()) + 1
        if not (cond_fn.t_stop < t_now < cond_fn.t_start):
            # Guidance is only applied inside the configured timestep window.
            self.context["g_apply"] = False
            return pred_x0
        # Rescale the guidance delta so that its effect on the posterior mean
        # matches a unit step in x_0 space (posterior_mean_coef1 multiplies
        # pred_x0 when the mean is computed).
        grad_rescale = 1 / extract_into_tensor(self.posterior_mean_coef1, index, pred_x0.shape)
|
|
|
        loss_vals = []
        for _ in range(cond_fn.repeat):
            # Prepare the target and the prediction in the space where the
            # guidance loss is computed.
            target, pred = None, None
|
if cond_fn.space == "latent": |
|
target = model.vae_encode(cond_fn.target) |
|
pred = pred_x0 |
|
            elif cond_fn.space == "rgb":
                # Gradients must flow through the VAE decoder, so enable grad
                # and track pred_x0 explicitly.
                with torch.enable_grad():
|
target = cond_fn.target |
|
pred_x0_rg = pred_x0.detach().clone().requires_grad_(True) |
|
pred = model.vae_decode(pred_x0_rg) |
|
assert pred.requires_grad |
|
else: |
|
raise NotImplementedError(cond_fn.space) |
|
|
|
delta_pred, loss_val = cond_fn(target, pred, t_now) |
|
loss_vals.append(loss_val) |
|
|
|
if cond_fn.space == "latent": |
|
delta_pred_x0 = delta_pred |
|
pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale |
|
            elif cond_fn.space == "rgb":
                # Backpropagate the rgb-space delta through the decoder to map
                # it back to a latent-space delta.
                pred.backward(delta_pred)
                delta_pred_x0 = pred_x0_rg.grad
                pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
|
else: |
|
raise NotImplementedError(cond_fn.space) |
|
self.context["g_apply"] = True |
|
self.context["g_loss"] = float(np.mean(loss_vals)) |
|
return pred_x0 |
|
|
|
def predict_noise( |
|
self, |
|
model: ControlLDM, |
|
x: torch.Tensor, |
|
t: torch.Tensor, |
|
cond: Dict[str, torch.Tensor], |
|
uncond: Optional[Dict[str, torch.Tensor]], |
|
cfg_scale: float |
|
) -> torch.Tensor: |
|
if uncond is None or cfg_scale == 1.: |
|
model_output = model(x, t, cond) |
|
        else:
            # Classifier-free guidance: extrapolate from the unconditional
            # prediction toward the conditional one.
            model_cond = model(x, t, cond)
|
model_uncond = model(x, t, uncond) |
|
model_output = model_uncond + cfg_scale * (model_cond - model_uncond) |
|
return model_output |
|
|
|
@torch.no_grad() |
|
def predict_noise_tiled( |
|
self, |
|
model: ControlLDM, |
|
x: torch.Tensor, |
|
t: torch.Tensor, |
|
cond: Dict[str, torch.Tensor], |
|
uncond: Optional[Dict[str, torch.Tensor]], |
|
cfg_scale: float, |
|
tile_size: int, |
|
tile_stride: int |
|
    ) -> torch.Tensor:
|
_, _, h, w = x.shape |
|
tiles = tqdm(sliding_windows(h, w, tile_size // 8, tile_stride // 8), unit="tile", leave=False) |
|
eps = torch.zeros_like(x) |
|
count = torch.zeros_like(x, dtype=torch.float32) |
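        # Per-tile Gaussian weights downweight tile borders so that seams
        # average out in the overlapping regions.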
|
weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None] |
|
weights = torch.tensor(weights, dtype=torch.float32, device=x.device) |
|
for hi, hi_end, wi, wi_end in tiles: |
|
tiles.set_description(f"Process tile ({hi} {hi_end}), ({wi} {wi_end})") |
|
tile_x = x[:, :, hi:hi_end, wi:wi_end] |
|
tile_cond = { |
|
"c_img": cond["c_img"][:, :, hi:hi_end, wi:wi_end], |
|
"c_txt": cond["c_txt"] |
|
} |
|
            if uncond:
                tile_uncond = {
                    "c_img": uncond["c_img"][:, :, hi:hi_end, wi:wi_end],
                    "c_txt": uncond["c_txt"]
                }
            else:
                # Without this branch, `tile_uncond` would be unbound when no
                # unconditional inputs are given.
                tile_uncond = None
            tile_eps = self.predict_noise(model, tile_x, t, tile_cond, tile_uncond, cfg_scale)
|
|
|
            # Accumulate Gaussian-weighted noise predictions so that
            # overlapping tiles blend smoothly.
            eps[:, :, hi:hi_end, wi:wi_end] += tile_eps * weights
            count[:, :, hi:hi_end, wi:wi_end] += weights
        # Normalize overlapping regions by the accumulated weights.
        eps.div_(count)
|
return eps |
|
|
|
@torch.no_grad() |
|
def p_sample( |
|
self, |
|
model: ControlLDM, |
|
x: torch.Tensor, |
|
t: torch.Tensor, |
|
index: torch.Tensor, |
|
cond: Dict[str, torch.Tensor], |
|
uncond: Optional[Dict[str, torch.Tensor]], |
|
cfg_scale: float, |
|
cond_fn: Optional[Guidance], |
|
tiled: bool, |
|
tile_size: int, |
|
tile_stride: int |
|
) -> torch.Tensor: |
|
if tiled: |
|
eps = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale, tile_size, tile_stride) |
|
else: |
|
eps = self.predict_noise(model, x, t, cond, uncond, cfg_scale) |
|
pred_x0 = self._predict_xstart_from_eps(x, index, eps) |
|
if cond_fn: |
|
            assert not tiled, "tiled sampling currently doesn't support guidance"
|
pred_x0 = self.apply_cond_fn(model, pred_x0, t, index, cond_fn) |
|
model_mean, model_variance, _ = self.q_posterior_mean_variance(pred_x0, x, index) |
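        # Ancestral sampling step: x_{t-1} = mean + sigma * z, with no noise
        # added at the final step (index == 0).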
|
noise = torch.randn_like(x) |
|
nonzero_mask = ( |
|
(index != 0).float().view(-1, *([1] * (len(x.shape) - 1))) |
|
) |
|
x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise |
|
return x_prev |
|
|
|
@torch.no_grad() |
|
def sample( |
|
self, |
|
model: ControlLDM, |
|
device: str, |
|
steps: int, |
|
batch_size: int, |
|
        x_size: Tuple[int, int, int],
|
cond: Dict[str, torch.Tensor], |
|
        uncond: Optional[Dict[str, torch.Tensor]],
|
cfg_scale: float, |
|
cond_fn: Optional[Guidance]=None, |
|
tiled: bool=False, |
|
tile_size: int=-1, |
|
tile_stride: int=-1, |
|
x_T: Optional[torch.Tensor]=None, |
|
progress: bool=True, |
|
progress_leave: bool=True, |
|
) -> torch.Tensor: |
|
self.make_schedule(steps) |
|
self.to(device) |
|
        if x_T is None:
            # Start from pure Gaussian noise unless an initial x_T is given.
            img = torch.randn((batch_size, *x_size), device=device)
|
else: |
|
img = x_T |
|
timesteps = np.flip(self.timesteps) |
|
total_steps = len(self.timesteps) |
|
iterator = tqdm(timesteps, total=total_steps, leave=progress_leave, disable=not progress) |
|
for i, step in enumerate(iterator): |
|
ts = torch.full((batch_size,), step, device=device, dtype=torch.long) |
|
index = torch.full_like(ts, fill_value=total_steps - i - 1) |
|
img = self.p_sample( |
|
model, img, ts, index, cond, uncond, cfg_scale, cond_fn, |
|
tiled, tile_size, tile_stride |
|
) |
|
if cond_fn and self.context["g_apply"]: |
|
loss_val = self.context["g_loss"] |
|
desc = f"Spaced Sampler With Guidance, Loss: {loss_val:.6f}" |
|
else: |
|
desc = "Spaced Sampler" |
|
iterator.set_description(desc) |
|
return img |
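

# A minimal usage sketch (assumptions: a loaded ControlLDM `model`, its training
# beta schedule `betas`, and `cond`/`uncond` dicts with "c_img"/"c_txt" entries
# prepared elsewhere; none of these are defined in this file, and the latent
# shape and cfg_scale below are illustrative values only):
#
#   sampler = SpacedSampler(betas)
#   z = sampler.sample(
#       model, device="cuda", steps=50, batch_size=1, x_size=(4, 64, 64),
#       cond=cond, uncond=uncond, cfg_scale=4.0,
#   )
#   sample = model.vae_decode(z)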
|
|